Commit 5a683756 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Re-format kernel

parent b618806b
...@@ -24,996 +24,1116 @@ ...@@ -24,996 +24,1116 @@
#include "attention/dtype_fp8.cuh" #include "attention/dtype_fp8.cuh"
#include "quantization/fp8/amd/quant_utils.cuh" #include "quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ #if defined(__HIPCC__) && \
defined(__gfx941__) || defined(__gfx942__)) (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__ #define __HIP__MI300_MI250__
#endif #endif
#if defined(NDEBUG) #if defined(NDEBUG)
#undef NDEBUG #undef NDEBUG
#include <assert.h> #include <assert.h>
#define UNREACHABLE_CODE assert(false); #define UNREACHABLE_CODE assert(false);
#define NDEBUG #define NDEBUG
#else #else
#define UNREACHABLE_CODE assert(false); #define UNREACHABLE_CODE assert(false);
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 = using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
__attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4; typedef float16x4 _Half4;
typedef struct _Half8 { typedef struct _Half8
_Half4 xy[2]; {
_Half4 xy[2];
} _Half8; } _Half8;
using bit16_t = uint16_t; using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4; typedef bit16x4 _B16x4;
typedef struct _B16x8 { typedef struct _B16x8
_B16x4 xy[2]; {
_B16x4 xy[2];
} _B16x8; } _B16x8;
using _B8x8 = uint2; using _B8x8 = uint2;
using bit8_t = uint8_t; using bit8_t = uint8_t;
////// Non temporal load stores /////// ////// Non temporal load stores ///////
template <typename T> template <typename T>
__device__ __forceinline__ T load(T* addr) { __device__ __forceinline__ T load(T* addr)
return addr[0]; {
return addr[0];
} }
template <typename T> template <typename T>
__device__ __forceinline__ void store(T value, T* addr) { __device__ __forceinline__ void store(T value, T* addr)
addr[0] = value; {
addr[0] = value;
} }
template <typename T, int absz, int cbid, int blgp> template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
const _B16x4& inpB, const _B16x4& inpB,
const floatx4& inpC) { const floatx4& inpC)
if constexpr (std::is_same<T, _Float16>::value) { {
return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, if constexpr(std::is_same<T, _Float16>::value)
blgp); {
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp);
return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, }
blgp); else if constexpr(std::is_same<T, __hip_bfloat16>::value)
} else { {
static_assert(false, "unsupported 16b dtype"); return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, blgp);
} }
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T> template <typename T>
__device__ __forceinline__ float to_float(const T& inp) { __device__ __forceinline__ float to_float(const T& inp)
if constexpr (std::is_same<T, _Float16>::value) { {
return (float)inp; if constexpr(std::is_same<T, _Float16>::value)
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { {
return __bfloat162float(inp); return (float)inp;
} else { }
static_assert(false, "unsupported 16b dtype"); else if constexpr(std::is_same<T, __hip_bfloat16>::value)
} {
return __bfloat162float(inp);
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T> template <typename T>
__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { __device__ __forceinline__ float to_float_b16(const bit16_t& inp)
union tmpcvt { {
bit16_t u; union tmpcvt
_Float16 f; {
__hip_bfloat16 b; bit16_t u;
} t16; _Float16 f;
t16.u = inp; __hip_bfloat16 b;
if constexpr (std::is_same<T, _Float16>::value) { } t16;
return (float)t16.f; t16.u = inp;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { if constexpr(std::is_same<T, _Float16>::value)
return __bfloat162float(t16.b); {
} else { return (float)t16.f;
static_assert(false, "unsupported 16b dtype"); }
} else if constexpr(std::is_same<T, __hip_bfloat16>::value)
{
return __bfloat162float(t16.b);
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T> template <typename T>
__device__ __forceinline__ T from_float(const float& inp) { __device__ __forceinline__ T from_float(const float& inp)
if constexpr (std::is_same<T, _Float16>::value) { {
return (_Float16)inp; if constexpr(std::is_same<T, _Float16>::value)
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { {
return __float2bfloat16(inp); return (_Float16)inp;
} else { }
static_assert(false, "unsupported 16b dtype"); else if constexpr(std::is_same<T, __hip_bfloat16>::value)
} {
return __float2bfloat16(inp);
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T> template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp)
union tmpcvt { {
uint16_t u; union tmpcvt
_Float16 f; {
__hip_bfloat16 b; uint16_t u;
} t16; _Float16 f;
_B16x4 ret; __hip_bfloat16 b;
if constexpr (std::is_same<T, _Float16>::value) { } t16;
#pragma unroll _B16x4 ret;
for (int i = 0; i < 4; i++) { if constexpr(std::is_same<T, _Float16>::value)
t16.f = (_Float16)inp[i]; {
ret[i] = t16.u; #pragma unroll
} for(int i = 0; i < 4; i++)
return ret; {
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { t16.f = (_Float16)inp[i];
#pragma unroll ret[i] = t16.u;
for (int i = 0; i < 4; i++) { }
t16.b = __float2bfloat16(inp[i]); return ret;
ret[i] = t16.u; }
} else if constexpr(std::is_same<T, __hip_bfloat16>::value)
return ret; {
} else { #pragma unroll
static_assert(false, "unsupported 16b dtype"); for(int i = 0; i < 4; i++)
} {
t16.b = __float2bfloat16(inp[i]);
ret[i] = t16.u;
}
return ret;
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T> template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, const _B16x4& inp2)
const _B16x4& inp2) { {
union tmpcvt { union tmpcvt
uint16_t u; {
_Float16 f; uint16_t u;
__hip_bfloat16 b; _Float16 f;
} t1, t2, res; __hip_bfloat16 b;
_B16x4 ret; } t1, t2, res;
if constexpr (std::is_same<T, _Float16>::value) { _B16x4 ret;
#pragma unroll if constexpr(std::is_same<T, _Float16>::value)
for (int i = 0; i < 4; i++) { {
t1.u = inp1[i]; #pragma unroll
t2.u = inp2[i]; for(int i = 0; i < 4; i++)
res.f = t1.f + t2.f; {
ret[i] = res.u; t1.u = inp1[i];
} t2.u = inp2[i];
return ret; res.f = t1.f + t2.f;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { ret[i] = res.u;
#pragma unroll }
for (int i = 0; i < 4; i++) { return ret;
t1.u = inp1[i]; }
t2.u = inp2[i]; else if constexpr(std::is_same<T, __hip_bfloat16>::value)
res.b = t1.b + t2.b; {
ret[i] = res.u; #pragma unroll
} for(int i = 0; i < 4; i++)
return ret; {
} else { t1.u = inp1[i];
static_assert(false, "unsupported 16b dtype"); t2.u = inp2[i];
} res.b = t1.b + t2.b;
ret[i] = res.u;
}
return ret;
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE> template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, const float scale)
const float scale) { {
union alignas(16) { union alignas(16)
uint4 u4; {
_B16x8 u16x8; uint4 u4;
vllm::bf16_8_t b16x8; _B16x8 u16x8;
} tmp; vllm::bf16_8_t b16x8;
if constexpr (std::is_same<T, _Float16>::value) { } tmp;
tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale); if constexpr(std::is_same<T, _Float16>::value)
return tmp.u16x8; {
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) { tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>( return tmp.u16x8;
input, scale); }
return tmp.u16x8; else if constexpr(std::is_same<T, __hip_bfloat16>::value)
} else { {
static_assert(false, "unsupported 16b dtype"); tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(input, scale);
} return tmp.u16x8;
}
else
{
static_assert(false, "unsupported 16b dtype");
}
} }
/////////////////////////////////////// ///////////////////////////////////////
// grid (num_seqs, num_partitions,num_heads/gqa_ratio) // grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size) // block (partition size)
template <typename scalar_t, typename cache_t, template <typename scalar_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE, typename cache_t,
int HEAD_SIZE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename OUTT,
int BLOCK_SIZE,
int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO> int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const float scale,
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const int kv_block_stride,
float* __restrict__ max_logits, // [num_seqs, num_heads, const int kv_head_stride,
// max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, float* __restrict__ max_logits, // [num_seqs, num_heads,
// head_size] // max_num_partitions]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
int max_ctx_blocks, float k_scale, float v_scale, // head_size]
const float* __restrict__ fp8_out_scale_ptr) { OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
constexpr int NWARPS = NUM_THREADS / WARP_SIZE; int max_ctx_blocks,
const int warpid = threadIdx.x / WARP_SIZE; float k_scale,
const int laneid = threadIdx.x % WARP_SIZE; float v_scale,
const int lane4id = laneid % 4; const float* __restrict__ fp8_out_scale_ptr)
{
const int seq_idx = blockIdx.x; constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int partition_idx = blockIdx.y; const int warpid = threadIdx.x / WARP_SIZE;
const int partition_size = blockDim.x; const int laneid = threadIdx.x % WARP_SIZE;
const int max_num_partitions = gridDim.y; const int lane4id = laneid % 4;
const int context_len = context_lens[seq_idx]; const int seq_idx = blockIdx.x;
const int partition_start_token_idx = partition_idx * partition_size; const int partition_idx = blockIdx.y;
// exit if partition is out of context for seq const int partition_size = blockDim.x;
if (partition_start_token_idx >= context_len) { const int max_num_partitions = gridDim.y;
return;
} const int context_len = context_lens[seq_idx];
constexpr int QHLOOP = const int partition_start_token_idx = partition_idx * partition_size;
DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, // exit if partition is out of context for seq
// total qheads =8, so qhloop is 2 if(partition_start_token_idx >= context_len)
constexpr int GQA_RATIO4 = 4 * QHLOOP; {
__shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; return;
__shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; }
_B16x8 Qlocal[QHLOOP]; constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads,
constexpr int x = 16 / sizeof(scalar_t); // total qheads =8, so qhloop is 2
constexpr int KHELOOP = HEAD_SIZE / x; constexpr int GQA_RATIO4 = 4 * QHLOOP;
_B16x8 Klocal[KHELOOP]; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
_B8x8 Klocalb8[KHELOOP]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
constexpr int VHELOOP = _B16x8 Qlocal[QHLOOP];
HEAD_SIZE / constexpr int x = 16 / sizeof(scalar_t);
WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int KHELOOP = HEAD_SIZE / x;
constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 _B16x8 Klocal[KHELOOP];
// 8xtokens _B8x8 Klocalb8[KHELOOP];
_B16x8 Vlocal[VHELOOP][VTLOOP]; constexpr int VHELOOP =
_B8x8 Vlocalb8[VHELOOP][VTLOOP]; HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes
floatx4 dout[QHLOOP]; constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2
float qk_max[QHLOOP]; // 8xtokens
#pragma unroll _B16x8 Vlocal[VHELOOP][VTLOOP];
for (int h = 0; h < QHLOOP; h++) { _B8x8 Vlocalb8[VHELOOP][VTLOOP];
dout[h] = {0}; floatx4 dout[QHLOOP];
qk_max[h] = -FLT_MAX; float qk_max[QHLOOP];
} #pragma unroll
for(int h = 0; h < QHLOOP; h++)
const int wg_start_head_idx = blockIdx.z * GQA_RATIO; {
const int wg_start_kv_head_idx = blockIdx.z; dout[h] = {0};
qk_max[h] = -FLT_MAX;
const int warp_start_token_idx = }
partition_start_token_idx + warpid * WARP_SIZE;
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
if (warp_start_token_idx >= context_len) { // warp out of context const int wg_start_kv_head_idx = blockIdx.z;
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) { const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE;
shared_qk_max[warpid][h] = -FLT_MAX;
shared_exp_sum[warpid][h] = 0.0f; if(warp_start_token_idx >= context_len)
} { // warp out of context
} else { // warp within context #pragma unroll
for(int h = 0; h < GQA_RATIO4; h++)
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); {
const int last_ctx_block = num_context_blocks - 1; shared_qk_max[warpid][h] = -FLT_MAX;
shared_exp_sum[warpid][h] = 0.0f;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; }
}
const int local_token_idx = threadIdx.x; else
const int global_token_idx = partition_start_token_idx + local_token_idx; { // warp within context
const int block_idx = (global_token_idx < context_len) const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
? global_token_idx / BLOCK_SIZE const int last_ctx_block = num_context_blocks - 1;
: last_ctx_block;
// fetch block number for q and k const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride const int local_token_idx = threadIdx.x;
const int64_t physical_block_number = const int global_token_idx = partition_start_token_idx + local_token_idx;
static_cast<int64_t>(block_table[block_idx]);
const int block_idx =
// fetch vphysical block numbers up front (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block;
constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; // fetch block number for q and k
int vphysical_blocks[VBLOCKS];
const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
if constexpr (GQA_RATIO < 12) {
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
}
// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;
#pragma unroll
for (int h = 0; h < QHLOOP - 1; h++) {
const int qhead_idx = h * 4 + lane4id;
Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
}
const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
if (final_qhead_idx < GQA_RATIO) {
Qlocal[QHLOOP - 1] =
q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
} else {
Qlocal[QHLOOP - 1].xy[0] = {0};
Qlocal[QHLOOP - 1].xy[1] = {0};
}
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
wg_start_kv_head_idx * kv_head_stride;
const int physical_block_offset =
local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
// is already cast as _H8
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
}
} else {
constexpr int X = 16 / sizeof(cache_t);
const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
const int head_elem = d * 8;
const int offset1 = head_elem / X;
const int offset2 = head_elem % X;
const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
}
}
float alibi_slope[QHLOOP];
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int qhead_idx = h * 4 + lane4id;
alibi_slope[h] = (qhead_idx < GQA_RATIO)
? alibi_slopes[wg_start_head_idx + qhead_idx]
: 0.f;
}
}
// fetch vphysical block numbers up front
if constexpr (GQA_RATIO >= 12) {
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
}
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number =
static_cast<int64_t>(vphysical_blocks[b]);
const _B16x8* v_ptrh8b =
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for (int h = 0; h < VHELOOP; h++) {
const int head_size_elem = h * WARP_SIZE + laneid;
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
}
}
}
} else {
const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with // int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride // kv_block_stride
const int64_t vphysical_block_number = const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
static_cast<int64_t>(vphysical_blocks[b]);
const _B8x8* v_ptrh8b = // fetch vphysical block numbers up front
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
// iterate over each head elem (within head_size) int vphysical_blocks[VBLOCKS];
#pragma unroll
for (int h = 0; h < VHELOOP; h++) { const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
const int head_size_elem = h * WARP_SIZE + laneid; if constexpr(GQA_RATIO < 12)
const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; {
// iterate over all velems within block #pragma unroll
#pragma unroll for(int b = 0; b < VBLOCKS; b++)
for (int d = 0; d < BLOCK_SIZE / 8; d++) { {
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; const int vblock_idx = warp_start_block_idx + b;
const _B8x8 Vlocalb8 = v_ptrh8be[d]; const int vblock_idx_ctx =
Vlocal[h][b * BLOCK_SIZE / 8 + d] = (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale); vphysical_blocks[b] = block_table[vblock_idx_ctx];
} }
} }
}
} // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
#pragma unroll const int qhead_elemh8 = laneid / 4;
for (int d = 0; d < KHELOOP; d++) { #pragma unroll
Klocal[d] = for(int h = 0; h < QHLOOP - 1; h++)
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale); {
} const int qhead_idx = h * 4 + lane4id;
} Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
}
#pragma unroll const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
for (int h = 0; h < QHLOOP; h++) { if(final_qhead_idx < GQA_RATIO)
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0], {
Klocal[0].xy[0], dout[h]); Qlocal[QHLOOP - 1] = q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1], }
Klocal[0].xy[1], dout[h]); else
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0], {
Klocal[1].xy[0], dout[h]); Qlocal[QHLOOP - 1].xy[0] = {0};
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1], Qlocal[QHLOOP - 1].xy[1] = {0};
Klocal[1].xy[1], dout[h]); }
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
Klocal[2].xy[0], dout[h]); const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1], wg_start_kv_head_idx * kv_head_stride;
Klocal[2].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0], const int physical_block_offset =
Klocal[3].xy[0], dout[h]); local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1], // is already cast as _H8
Klocal[3].xy[1], dout[h]); if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto)
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0], {
Klocal[4].xy[0], dout[h]); const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1], #pragma unroll
Klocal[4].xy[1], dout[h]); for(int d = 0; d < KHELOOP; d++)
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0], {
Klocal[5].xy[0], dout[h]); Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
Klocal[5].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
Klocal[6].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
Klocal[6].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
Klocal[7].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
Klocal[7].xy[1], dout[h]);
if constexpr (KHELOOP > 8) {
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
Klocal[8].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
Klocal[8].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
Klocal[9].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
Klocal[9].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
Klocal[10].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
Klocal[10].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
Klocal[11].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
Klocal[11].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
Klocal[12].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
Klocal[12].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
Klocal[13].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
Klocal[13].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
Klocal[14].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
Klocal[14].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
Klocal[15].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
Klocal[15].xy[1], dout[h]);
} // KHELOOP>8
dout[h] *= scale;
}
// transpose dout so that 4 token ids are in each lane, and 4 heads are across
// 4 lanes
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
floatx4 tmp = {0};
#pragma unroll
for (int i = 0; i < 4; i++) {
const float B = (lane4id == i) ? 1.0f : 0.0f;
// const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
// tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
}
dout[h] = tmp;
}
const int lane4_token_idx = 4 * (global_token_idx >> 2);
const int alibi_offset = lane4_token_idx - context_len + 1;
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] += alibi_slope[h] * (alibi_offset + i);
}
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
#pragma unroll
for (int i = 0; i < 4; i++) {
qk_max[h] = (lane4_token_idx + i < context_len)
? fmaxf(qk_max[h], dout[h][i])
: qk_max[h];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
}
}
float exp_sum[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] = (lane4_token_idx + i < context_len)
? __expf(dout[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += dout[h][i];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
exp_sum[h] += __shfl_xor(exp_sum[h], mask);
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int head_idx = 4 * h + lane4id;
shared_qk_max[warpid][head_idx] = qk_max[h];
shared_exp_sum[warpid][head_idx] = exp_sum[h];
}
} // warp within context
__syncthreads();
const int num_heads = gridDim.z * GQA_RATIO;
float* max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
float* exp_sums_ptr =
exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
float global_qk_max = -FLT_MAX;
float warp_qk_max[NWARPS];
const int head_idx = 4 * h + lane4id;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max[w] = shared_qk_max[w][head_idx];
global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
}
float global_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
global_exp_sum +=
shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
}
if (head_idx < GQA_RATIO) {
max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_qk_max;
exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_exp_sum;
}
const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
__expf(qk_max[h] - global_qk_max);
dout[h] *= global_inv_sum_scale;
}
// logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
// are 4x16 tokens across warp
_B16x4 logits[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
logits[h] = from_floatx4<scalar_t>(dout[h]);
}
__shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
}
}
} else { // warp in context
// iterate across heads
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
floatx4 acc = {0};
// iterate over tokens
acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
Vlocal[vh][5].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
Vlocal[vh][5].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
Vlocal[vh][6].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
Vlocal[vh][6].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
Vlocal[vh][7].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
Vlocal[vh][7].xy[1], acc);
vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
}
}
} // warp in context
__syncthreads();
if (warpid == 0) {
// const float out_scale = (fp8_out_scale_ptr != nullptr) ?
// __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
const float out_scale =
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
_B16x4 vout[QHLOOP][VHELOOP];
// iterate across heads
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout[qh][vh] = {0};
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
vout[qh][vh] =
addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
}
}
}
if (context_len > partition_size) {
scalar_t* out_ptr = out +
seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
const int out_num_partitions = max_num_partitions;
bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
const int head_size_elem = vh * WARP_SIZE + laneid;
#pragma unroll
for (int i = 0; i < 4; i++) {
const int head_idx = 4 * qh + i;
if (head_idx < GQA_RATIO) {
out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
HEAD_SIZE +
head_size_elem] = vout[qh][vh][i];
} }
} }
} else
} {
} // context_len > partition_size constexpr int X = 16 / sizeof(cache_t);
else { const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
bit8_t* final_out_ptr_b8; #pragma unroll
bit16_t* final_out_ptr_b16; for(int d = 0; d < KHELOOP; d++)
if constexpr (std::is_same<OUTT, bit8_t>::value) { {
final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE; const int head_elem = d * 8;
} else { const int offset1 = head_elem / X;
OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; const int offset2 = head_elem % X;
final_out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr); const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
} Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
#pragma unroll }
for (int qh = 0; qh < QHLOOP; qh++) { }
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) { float alibi_slope[QHLOOP];
const int head_size_elem = vh * WARP_SIZE + laneid; if(alibi_slopes != nullptr)
#pragma unroll {
for (int i = 0; i < 4; i++) { #pragma unroll
const int head_idx = 4 * qh + i; for(int h = 0; h < QHLOOP; h++)
if (head_idx < GQA_RATIO) { {
if constexpr (std::is_same<OUTT, bit8_t>::value) { const int qhead_idx = h * 4 + lane4id;
const float tmpf = alibi_slope[h] =
out_scale * to_float_b16<scalar_t>(vout[qh][vh][i]); (qhead_idx < GQA_RATIO) ? alibi_slopes[wg_start_head_idx + qhead_idx] : 0.f;
const OUTT tmp = hip_fp8(tmpf).data; }
final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE + }
head_size_elem] = tmp;
} else { // fetch vphysical block numbers up front
final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE + if constexpr(GQA_RATIO >= 12)
head_size_elem] = vout[qh][vh][i]; {
} #pragma unroll
for(int b = 0; b < VBLOCKS; b++)
{
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
} }
}
} }
}
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto)
{
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for(int b = 0; b < VBLOCKS; b++)
{
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number = static_cast<int64_t>(vphysical_blocks[b]);
const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for(int h = 0; h < VHELOOP; h++)
{
const int head_size_elem = h * WARP_SIZE + laneid;
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for(int d = 0; d < BLOCK_SIZE / 8; d++)
{
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
}
}
}
}
else
{
const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for(int b = 0; b < VBLOCKS; b++)
{
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number = static_cast<int64_t>(vphysical_blocks[b]);
const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for(int h = 0; h < VHELOOP; h++)
{
const int head_size_elem = h * WARP_SIZE + laneid;
const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for(int d = 0; d < BLOCK_SIZE / 8; d++)
{
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const _B8x8 Vlocalb8 = v_ptrh8be[d];
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
}
}
}
}
if constexpr(KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto)
{
#pragma unroll
for(int d = 0; d < KHELOOP; d++)
{
Klocal[d] = scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
}
}
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1], Klocal[0].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0], Klocal[1].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1], Klocal[1].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0], Klocal[2].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1], Klocal[2].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0], Klocal[3].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1], Klocal[3].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0], Klocal[4].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1], Klocal[4].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0], Klocal[5].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1], Klocal[5].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0], Klocal[6].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1], Klocal[6].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0], Klocal[7].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1], Klocal[7].xy[1], dout[h]);
if constexpr(KHELOOP > 8)
{
dout[h] =
gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h]);
dout[h] =
gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h]);
} // KHELOOP>8
dout[h] *= scale;
}
// transpose dout so that 4 token ids are in each lane, and 4 heads are across
// 4 lanes
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
floatx4 tmp = {0};
#pragma unroll
for(int i = 0; i < 4; i++)
{
const float B = (lane4id == i) ? 1.0f : 0.0f;
// const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
// tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
}
dout[h] = tmp;
}
const int lane4_token_idx = 4 * (global_token_idx >> 2);
const int alibi_offset = lane4_token_idx - context_len + 1;
if(alibi_slopes != nullptr)
{
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
#pragma unroll
for(int i = 0; i < 4; i++)
{
dout[h][i] += alibi_slope[h] * (alibi_offset + i);
}
}
}
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
qk_max[h] = -FLT_MAX;
#pragma unroll
for(int i = 0; i < 4; i++)
{
qk_max[h] =
(lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h];
}
#pragma unroll
for(int mask = WARP_SIZE / 2; mask >= 4; mask /= 2)
{
qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
}
}
float exp_sum[QHLOOP];
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
exp_sum[h] = 0.0f;
#pragma unroll
for(int i = 0; i < 4; i++)
{
dout[h][i] =
(lane4_token_idx + i < context_len) ? __expf(dout[h][i] - qk_max[h]) : 0.0f;
exp_sum[h] += dout[h][i];
}
#pragma unroll
for(int mask = WARP_SIZE / 2; mask >= 4; mask /= 2)
{
exp_sum[h] += __shfl_xor(exp_sum[h], mask);
}
}
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
const int head_idx = 4 * h + lane4id;
shared_qk_max[warpid][head_idx] = qk_max[h];
shared_exp_sum[warpid][head_idx] = exp_sum[h];
}
} // warp within context
__syncthreads();
const int num_heads = gridDim.z * GQA_RATIO;
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
float global_qk_max = -FLT_MAX;
float warp_qk_max[NWARPS];
const int head_idx = 4 * h + lane4id;
#pragma unroll
for(int w = 0; w < NWARPS; w++)
{
warp_qk_max[w] = shared_qk_max[w][head_idx];
global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
}
float global_exp_sum = 0.0f;
#pragma unroll
for(int w = 0; w < NWARPS; w++)
{
global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
}
if(head_idx < GQA_RATIO)
{
max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = global_qk_max;
exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = global_exp_sum;
}
const float global_inv_sum_scale =
__fdividef(1.f, global_exp_sum + 1e-6f) * __expf(qk_max[h] - global_qk_max);
dout[h] *= global_inv_sum_scale;
}
// logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
// are 4x16 tokens across warp
_B16x4 logits[QHLOOP];
#pragma unroll
for(int h = 0; h < QHLOOP; h++)
{
logits[h] = from_floatx4<scalar_t>(dout[h]);
} }
} // warpid == 0
__shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
if(warp_start_token_idx >= context_len)
{ // warp out of context
#pragma unroll
for(int qh = 0; qh < QHLOOP; qh++)
{
#pragma unroll
for(int vh = 0; vh < VHELOOP; vh++)
{
vout_shared[qh][vh][laneid][warpid] = {0};
}
}
}
else
{ // warp in context
// iterate across heads
#pragma unroll
for(int qh = 0; qh < QHLOOP; qh++)
{
// iterate over each v head elem (within head_size)
#pragma unroll
for(int vh = 0; vh < VHELOOP; vh++)
{
floatx4 acc = {0};
// iterate over tokens
acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh], Vlocal[vh][5].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh], Vlocal[vh][5].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh], Vlocal[vh][6].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh], Vlocal[vh][6].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh], Vlocal[vh][7].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh], Vlocal[vh][7].xy[1], acc);
vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
}
}
} // warp in context
__syncthreads();
if(warpid == 0)
{
// const float out_scale = (fp8_out_scale_ptr != nullptr) ?
// __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
_B16x4 vout[QHLOOP][VHELOOP];
// iterate across heads
#pragma unroll
for(int qh = 0; qh < QHLOOP; qh++)
{
// iterate over each v head elem (within head_size)
#pragma unroll
for(int vh = 0; vh < VHELOOP; vh++)
{
vout[qh][vh] = {0};
#pragma unroll
for(int w = 0; w < NWARPS; w++)
{
vout[qh][vh] = addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
}
}
}
if(context_len > partition_size)
{
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
const int out_num_partitions = max_num_partitions;
bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
for(int qh = 0; qh < QHLOOP; qh++)
{
#pragma unroll
for(int vh = 0; vh < VHELOOP; vh++)
{
const int head_size_elem = vh * WARP_SIZE + laneid;
#pragma unroll
for(int i = 0; i < 4; i++)
{
const int head_idx = 4 * qh + i;
if(head_idx < GQA_RATIO)
{
out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
HEAD_SIZE +
head_size_elem] = vout[qh][vh][i];
}
}
}
}
} // context_len > partition_size
else
{
bit8_t* final_out_ptr_b8;
bit16_t* final_out_ptr_b16;
if constexpr(std::is_same<OUTT, bit8_t>::value)
{
final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE;
}
else
{
OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
final_out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
}
#pragma unroll
for(int qh = 0; qh < QHLOOP; qh++)
{
#pragma unroll
for(int vh = 0; vh < VHELOOP; vh++)
{
const int head_size_elem = vh * WARP_SIZE + laneid;
#pragma unroll
for(int i = 0; i < 4; i++)
{
const int head_idx = 4 * qh + i;
if(head_idx < GQA_RATIO)
{
if constexpr(std::is_same<OUTT, bit8_t>::value)
{
const float tmpf =
out_scale * to_float_b16<scalar_t>(vout[qh][vh][i]);
const OUTT tmp = hip_fp8(tmpf).data;
final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE +
head_size_elem] = tmp;
}
else
{
final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE +
head_size_elem] = vout[qh][vh][i];
}
}
}
}
}
}
} // warpid == 0
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t,
int PARTITION_SIZE, int NPAR_LOOPS> typename OUTT,
__global__ int HEAD_SIZE,
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( int NUM_THREADS,
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] int PARTITION_SIZE,
const float* __restrict__ exp_sums, // [num_seqs, num_heads, int NPAR_LOOPS>
// max_num_partitions] __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
// max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
const int* __restrict__ context_lens, // [num_seqs] // max_num_partitions]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const int num_heads = gridDim.x; // max_num_partitions, head_size]
const int head_idx = blockIdx.x; const int* __restrict__ context_lens, // [num_seqs]
const int seq_idx = blockIdx.y; const int max_num_partitions,
const int context_len = context_lens[seq_idx]; const float* __restrict__ fp8_out_scale_ptr)
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); {
if (num_partitions == 1) { const int num_heads = gridDim.x;
// if num_partitions==1, main kernel will write to out directly, no work in const int head_idx = blockIdx.x;
// reduction kernel const int seq_idx = blockIdx.y;
return; const int context_len = context_lens[seq_idx];
} const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
if(num_partitions == 1)
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; {
const int warpid = threadIdx.x / WARP_SIZE; // if num_partitions==1, main kernel will write to out directly, no work in
const int laneid = threadIdx.x % WARP_SIZE; // reduction kernel
return;
__shared__ float shared_global_exp_sum; }
// max num partitions supported is warp_size * NPAR_LOOPS
__shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
if (warpid == 0) { const int laneid = threadIdx.x % WARP_SIZE;
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions + __shared__ float shared_global_exp_sum;
head_idx * max_num_partitions; // max num partitions supported is warp_size * NPAR_LOOPS
__shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];
// valid partition is the last valid partition in case threadid > num
// partitions if(warpid == 0)
int valid_partition[NPAR_LOOPS]; {
float reg_max_logit[NPAR_LOOPS]; const float* max_logits_ptr =
const int last_valid_partition = num_partitions - 1; max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
#pragma unroll // valid partition is the last valid partition in case threadid > num
for (int i = 0; i < NPAR_LOOPS; i++) { // partitions
const int partition_no = i * WARP_SIZE + threadIdx.x; int valid_partition[NPAR_LOOPS];
valid_partition[i] = float reg_max_logit[NPAR_LOOPS];
(partition_no < num_partitions) ? partition_no : last_valid_partition; const int last_valid_partition = num_partitions - 1;
}
#pragma unroll #pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) { for(int i = 0; i < NPAR_LOOPS; i++)
reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; {
} const int partition_no = i * WARP_SIZE + threadIdx.x;
float max_logit = reg_max_logit[0]; valid_partition[i] =
#pragma unroll (partition_no < num_partitions) ? partition_no : last_valid_partition;
for (int i = 1; i < NPAR_LOOPS; i++) { }
max_logit = fmaxf(max_logit, reg_max_logit[i]); #pragma unroll
} for(int i = 0; i < NPAR_LOOPS; i++)
{
#pragma unroll reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { }
max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); float max_logit = reg_max_logit[0];
} #pragma unroll
for(int i = 1; i < NPAR_LOOPS; i++)
const float* exp_sums_ptr = exp_sums + {
seq_idx * num_heads * max_num_partitions + max_logit = fmaxf(max_logit, reg_max_logit[i]);
head_idx * max_num_partitions; }
float rescaled_exp_sum[NPAR_LOOPS]; #pragma unroll
#pragma unroll for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
for (int i = 0; i < NPAR_LOOPS; i++) { {
rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
} }
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) { const float* exp_sums_ptr =
const int partition_no = i * WARP_SIZE + threadIdx.x; exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
rescaled_exp_sum[i] *= (partition_no < num_partitions)
? expf(reg_max_logit[i] - max_logit) float rescaled_exp_sum[NPAR_LOOPS];
: 0.0f; #pragma unroll
} for(int i = 0; i < NPAR_LOOPS; i++)
float global_exp_sum = rescaled_exp_sum[0]; {
#pragma unroll rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
for (int i = 1; i < NPAR_LOOPS; i++) { }
global_exp_sum += rescaled_exp_sum[i]; #pragma unroll
} for(int i = 0; i < NPAR_LOOPS; i++)
#pragma unroll {
for (int i = 0; i < NPAR_LOOPS; i++) { const int partition_no = i * WARP_SIZE + threadIdx.x;
const int partition_no = i * WARP_SIZE + threadIdx.x; rescaled_exp_sum[i] *=
shared_exp_sums[partition_no] = rescaled_exp_sum[i]; (partition_no < num_partitions) ? expf(reg_max_logit[i] - max_logit) : 0.0f;
} }
float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { for(int i = 1; i < NPAR_LOOPS; i++)
global_exp_sum += __shfl_xor(global_exp_sum, mask); {
} global_exp_sum += rescaled_exp_sum[i];
if (threadIdx.x == 0) { }
shared_global_exp_sum = global_exp_sum; #pragma unroll
} for(int i = 0; i < NPAR_LOOPS; i++)
} // warpid == 0 {
const scalar_t* tmp_out_ptr = const int partition_no = i * WARP_SIZE + threadIdx.x;
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + shared_exp_sums[partition_no] = rescaled_exp_sum[i];
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; }
constexpr int MAX_NPAR = 64;
scalar_t tmps[MAX_NPAR]; #pragma unroll
const float dzero = 0.0f; for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
#pragma unroll {
for (int j = 0; j < MAX_NPAR; j++) { global_exp_sum += __shfl_xor(global_exp_sum, mask);
tmps[j] = from_float<scalar_t>(dzero); }
} if(threadIdx.x == 0)
const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; {
const int num_partition_offset = (num_partitions)*HEAD_SIZE; shared_global_exp_sum = global_exp_sum;
int idx = 0; }
} // warpid == 0
constexpr int JCHUNK = 16; const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
#pragma unroll constexpr int MAX_NPAR = 64;
for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { scalar_t tmps[MAX_NPAR];
// lastj is last valid partition const float dzero = 0.0f;
const int lastj_offset = #pragma unroll
(j < num_partition_offset) ? j : last_partition_offset; for(int j = 0; j < MAX_NPAR; j++)
tmps[idx] = tmp_out_ptr[lastj_offset]; {
idx++; tmps[j] = from_float<scalar_t>(dzero);
}
__syncthreads();
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
}
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for (int j = 0; j < JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
} }
} const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
const int num_partition_offset = (num_partitions)*HEAD_SIZE;
int idx = 0;
constexpr int JCHUNK = 16;
for (int p = 1; p < NPAR_LOOPS; p++) { #pragma unroll
if (num_partitions > p * MAX_NPAR) { for(int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
idx = 0; {
#pragma unroll
for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
// lastj is last valid partition // lastj is last valid partition
const int lastj_offset = const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
(j < num_partition_offset) ? j : last_partition_offset; tmps[idx] = tmp_out_ptr[lastj_offset];
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++; idx++;
} }
__syncthreads();
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) { if(num_partitions > JCHUNK)
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; {
} #pragma unroll
} for(int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
} {
const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
const float inv_global_exp_sum = tmps[idx] = tmp_out_ptr[lastj_offset];
__fdividef(1.0f, shared_global_exp_sum + 1e-6f); idx++;
// const float out_scale = (fp8_out_scale_ptr != nullptr) ? }
// __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
const float out_scale = if(num_partitions > 2 * JCHUNK)
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; {
acc *= inv_global_exp_sum; #pragma unroll
acc *= out_scale; for(int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE)
OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; {
if constexpr (std::is_same<OUTT, bit8_t>::value) { const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
out_ptr[threadIdx.x] = hip_fp8(acc).data; tmps[idx] = tmp_out_ptr[lastj_offset];
} else { idx++;
out_ptr[threadIdx.x] = from_float<scalar_t>(acc); }
} }
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for(int j = 0; j < JCHUNK; j++)
{
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if(num_partitions > JCHUNK)
{
#pragma unroll
for(int j = JCHUNK; j < 2 * JCHUNK; j++)
{
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if(num_partitions > 2 * JCHUNK)
{
#pragma unroll
for(int j = 2 * JCHUNK; j < MAX_NPAR; j++)
{
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
}
}
for(int p = 1; p < NPAR_LOOPS; p++)
{
if(num_partitions > p * MAX_NPAR)
{
idx = 0;
#pragma unroll
for(int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE)
{
// lastj is last valid partition
const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
#pragma unroll
for(int j = 0; j < MAX_NPAR; j++)
{
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
}
}
}
const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
// const float out_scale = (fp8_out_scale_ptr != nullptr) ?
// __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
acc *= inv_global_exp_sum;
acc *= out_scale;
OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
if constexpr(std::is_same<OUTT, bit8_t>::value)
{
out_ptr[threadIdx.x] = hip_fp8(acc).data;
}
else
{
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
} }
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, typename cache_t, template <typename scalar_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE, typename cache_t,
int HEAD_SIZE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename OUTT,
int BLOCK_SIZE,
int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO> int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const float scale,
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const int kv_block_stride,
float* __restrict__ max_logits, // [num_seqs, num_heads, const int kv_head_stride,
// max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, float* __restrict__ max_logits, // [num_seqs, num_heads,
// head_size] // max_num_partitions]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
int max_ctx_blocks, float k_scale, float v_scale, // head_size]
const float* __restrict__ fp8_out_scale_ptr) { OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
UNREACHABLE_CODE int max_ctx_blocks,
float k_scale,
float v_scale,
const float* __restrict__ fp8_out_scale_ptr)
{
UNREACHABLE_CODE
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t,
int PARTITION_SIZE, int NPAR_LOOPS> typename OUTT,
__global__ int HEAD_SIZE,
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( int NUM_THREADS,
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] int PARTITION_SIZE,
const float* __restrict__ exp_sums, // [num_seqs, num_heads, int NPAR_LOOPS>
// max_num_partitions] __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
// max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
const int* __restrict__ context_lens, // [num_seqs] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions, const int max_num_partitions,
const float* __restrict__ fp8_out_scale_ptr){UNREACHABLE_CODE} const float* __restrict__ fp8_out_scale_ptr)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment