Commit 5a683756 authored by Po Yen, Chen

Re-format kernel

parent b618806b
@@ -24,41 +24,42 @@
#include "attention/dtype_fp8.cuh"
#include "quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && \
    (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8
{
    _Half4 xy[2];
} _Half8;
using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8
{
    _B16x4 xy[2];
} _B16x8;
@@ -68,144 +69,191 @@ using bit8_t = uint8_t;
////// Non temporal load stores ///////
template <typename T>
__device__ __forceinline__ T load(T* addr)
{
    return addr[0];
}
template <typename T>
__device__ __forceinline__ void store(T value, T* addr)
{
    addr[0] = value;
}
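// Thin wrapper over the 4x4x4 MFMA builtins: selects the fp16 or bf16 variant of the
// matrix fused-multiply-add at compile time based on the 16-bit element type T.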
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
                                                  const _B16x4& inpB,
                                                  const floatx4& inpC)
{
    if constexpr(std::is_same<T, _Float16>::value)
    {
        return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp);
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
        return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, blgp);
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
template <typename T>
__device__ __forceinline__ float to_float(const T& inp)
{
    if constexpr(std::is_same<T, _Float16>::value)
    {
        return (float)inp;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
        return __bfloat162float(inp);
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
template <typename T>
__device__ __forceinline__ float to_float_b16(const bit16_t& inp)
{
    union tmpcvt
    {
        bit16_t u;
        _Float16 f;
        __hip_bfloat16 b;
    } t16;
    t16.u = inp;
    if constexpr(std::is_same<T, _Float16>::value)
    {
        return (float)t16.f;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
        return __bfloat162float(t16.b);
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
template <typename T>
__device__ __forceinline__ T from_float(const float& inp)
{
    if constexpr(std::is_same<T, _Float16>::value)
    {
        return (_Float16)inp;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
        return __float2bfloat16(inp);
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp)
{
    union tmpcvt
    {
        uint16_t u;
        _Float16 f;
        __hip_bfloat16 b;
    } t16;
    _B16x4 ret;
    if constexpr(std::is_same<T, _Float16>::value)
    {
#pragma unroll
        for(int i = 0; i < 4; i++)
        {
            t16.f = (_Float16)inp[i];
            ret[i] = t16.u;
        }
        return ret;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
#pragma unroll
        for(int i = 0; i < 4; i++)
        {
            t16.b = __float2bfloat16(inp[i]);
            ret[i] = t16.u;
        }
        return ret;
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, const _B16x4& inp2)
{
    union tmpcvt
    {
        uint16_t u;
        _Float16 f;
        __hip_bfloat16 b;
    } t1, t2, res;
    _B16x4 ret;
    if constexpr(std::is_same<T, _Float16>::value)
    {
#pragma unroll
        for(int i = 0; i < 4; i++)
        {
            t1.u = inp1[i];
            t2.u = inp2[i];
            res.f = t1.f + t2.f;
            ret[i] = res.u;
        }
        return ret;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
#pragma unroll
        for(int i = 0; i < 4; i++)
        {
            t1.u = inp1[i];
            t2.u = inp2[i];
            res.b = t1.b + t2.b;
            ret[i] = res.u;
        }
        return ret;
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
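// Dequantize 8 packed fp8 cache values to the 16-bit compute type (fp16 or bf16),
// applying the per-tensor KV-cache scale.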
template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, const float scale)
{
    union alignas(16)
    {
        uint4 u4;
        _B16x8 u16x8;
        vllm::bf16_8_t b16x8;
    } tmp;
    if constexpr(std::is_same<T, _Float16>::value)
    {
        tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
        return tmp.u16x8;
    }
    else if constexpr(std::is_same<T, __hip_bfloat16>::value)
    {
        tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(input, scale);
        return tmp.u16x8;
    }
    else
    {
        static_assert(false, "unsupported 16b dtype");
    }
}
@@ -214,9 +262,13 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
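// each thread block handles one (sequence, context partition, kv-head group);
// the GQA_RATIO query heads that share the kv head indexed by blockIdx.z are processed together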
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          typename OUTT,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@@ -224,20 +276,26 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
                                    // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
                                         // head_size, block_size]
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads,
                                    // max_num_partitions]
    scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
                                // head_size]
    OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
    int max_ctx_blocks,
    float k_scale,
    float v_scale,
    const float* __restrict__ fp8_out_scale_ptr)
{
    constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
    const int warpid = threadIdx.x / WARP_SIZE;
    const int laneid = threadIdx.x % WARP_SIZE;
@@ -251,11 +309,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const int context_len = context_lens[seq_idx];
    const int partition_start_token_idx = partition_idx * partition_size;
    // exit if partition is out of context for seq
    if(partition_start_token_idx >= context_len)
    {
        return;
    }
    constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads,
                                                          // total qheads =8, so qhloop is 2
    constexpr int GQA_RATIO4 = 4 * QHLOOP;
    __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
@@ -266,16 +324,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    _B16x8 Klocal[KHELOOP];
    _B8x8 Klocalb8[KHELOOP];
    constexpr int VHELOOP =
        HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes
    constexpr int VTLOOP = 8;  // 16 separate 4xtokens across warp -> 16/2
                               // 8xtokens
    _B16x8 Vlocal[VHELOOP][VTLOOP];
    _B8x8 Vlocalb8[VHELOOP][VTLOOP];
    floatx4 dout[QHLOOP];
    float qk_max[QHLOOP];
#pragma unroll
    for(int h = 0; h < QHLOOP; h++)
    {
        dout[h] = {0};
        qk_max[h] = -FLT_MAX;
    }
@@ -283,16 +341,19 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
    const int wg_start_kv_head_idx = blockIdx.z;
    const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE;
    if(warp_start_token_idx >= context_len)
    { // warp out of context
#pragma unroll
        for(int h = 0; h < GQA_RATIO4; h++)
        {
            shared_qk_max[warpid][h] = -FLT_MAX;
            shared_exp_sum[warpid][h] = 0.0f;
        }
    }
    else
    { // warp within context
        const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
        const int last_ctx_block = num_context_blocks - 1;
@@ -302,23 +363,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        const int local_token_idx = threadIdx.x;
        const int global_token_idx = partition_start_token_idx + local_token_idx;
        const int block_idx =
            (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block;
        // fetch block number for q and k
        // int32 physical_block_number leads to overflow when multiplied with
        // kv_block_stride
        const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
        // fetch vphysical block numbers up front
        constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
        int vphysical_blocks[VBLOCKS];
        const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
        if constexpr(GQA_RATIO < 12)
        {
#pragma unroll
            for(int b = 0; b < VBLOCKS; b++)
            {
                const int vblock_idx = warp_start_block_idx + b;
                const int vblock_idx_ctx =
                    (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
@@ -327,20 +388,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        }
        // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
        const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
        const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
        const int qhead_elemh8 = laneid / 4;
#pragma unroll
        for(int h = 0; h < QHLOOP - 1; h++)
        {
            const int qhead_idx = h * 4 + lane4id;
            Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
        }
        const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
        if(final_qhead_idx < GQA_RATIO)
        {
            Qlocal[QHLOOP - 1] = q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
        }
        else
        {
            Qlocal[QHLOOP - 1].xy[0] = {0};
            Qlocal[QHLOOP - 1].xy[1] = {0};
        }
@@ -351,17 +414,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        const int physical_block_offset =
            local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
                                          // is already cast as _H8
        if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto)
        {
            const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
            for(int d = 0; d < KHELOOP; d++)
            {
                Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
            }
        }
        else
        {
            constexpr int X = 16 / sizeof(cache_t);
            const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
#pragma unroll
            for(int d = 0; d < KHELOOP; d++)
            {
                const int head_elem = d * 8;
                const int offset1 = head_elem / X;
                const int offset2 = head_elem % X;
@@ -371,20 +439,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        }
        float alibi_slope[QHLOOP];
        if(alibi_slopes != nullptr)
        {
#pragma unroll
            for(int h = 0; h < QHLOOP; h++)
            {
                const int qhead_idx = h * 4 + lane4id;
                alibi_slope[h] =
                    (qhead_idx < GQA_RATIO) ? alibi_slopes[wg_start_head_idx + qhead_idx] : 0.f;
            }
        }
        // fetch vphysical block numbers up front
        if constexpr(GQA_RATIO >= 12)
        {
#pragma unroll
            for(int b = 0; b < VBLOCKS; b++)
            {
                const int vblock_idx = warp_start_block_idx + b;
                const int vblock_idx_ctx =
                    (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
@@ -393,48 +464,53 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        }
        const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
        if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto)
        {
            const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
            // iterate over each v block
#pragma unroll
            for(int b = 0; b < VBLOCKS; b++)
            {
                // int32 physical_block_number leads to overflow when multiplied with
                // kv_block_stride
                const int64_t vphysical_block_number = static_cast<int64_t>(vphysical_blocks[b]);
                const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
                // iterate over each head elem (within head_size)
#pragma unroll
                for(int h = 0; h < VHELOOP; h++)
                {
                    const int head_size_elem = h * WARP_SIZE + laneid;
                    const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
                    // iterate over all velems within block
#pragma unroll
                    for(int d = 0; d < BLOCK_SIZE / 8; d++)
                    {
                        Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
                    }
                }
            }
        }
        else
        {
            const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
            // iterate over each v block
#pragma unroll
            for(int b = 0; b < VBLOCKS; b++)
            {
                // int32 physical_block_number leads to overflow when multiplied with
                // kv_block_stride
                const int64_t vphysical_block_number = static_cast<int64_t>(vphysical_blocks[b]);
                const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
                // iterate over each head elem (within head_size)
#pragma unroll
                for(int h = 0; h < VHELOOP; h++)
                {
                    const int head_size_elem = h * WARP_SIZE + laneid;
                    const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
                    // iterate over all velems within block
#pragma unroll
                    for(int d = 0; d < BLOCK_SIZE / 8; d++)
                    {
                        // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
                        const _B8x8 Vlocalb8 = v_ptrh8be[d];
                        Vlocal[h][b * BLOCK_SIZE / 8 + d] =
@@ -444,91 +520,80 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
            }
        }
        if constexpr(KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto)
        {
#pragma unroll
            for(int d = 0; d < KHELOOP; d++)
            {
                Klocal[d] = scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
            }
        }
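        // Q*K^T: accumulate the dot products chunk by chunk (8 head elements per Klocal
        // entry) with 4x4x4 MFMAs into dout[h], then apply the softmax scale.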
#pragma unroll
        for(int h = 0; h < QHLOOP; h++)
        {
            dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1], Klocal[0].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0], Klocal[1].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1], Klocal[1].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0], Klocal[2].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1], Klocal[2].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0], Klocal[3].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1], Klocal[3].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0], Klocal[4].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1], Klocal[4].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0], Klocal[5].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1], Klocal[5].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0], Klocal[6].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1], Klocal[6].xy[1], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0], Klocal[7].xy[0], dout[h]);
            dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1], Klocal[7].xy[1], dout[h]);
            if constexpr(KHELOOP > 8)
            {
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h]);
                dout[h] =
                    gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h]);
            } // KHELOOP>8
            dout[h] *= scale;
        }
        // transpose dout so that 4 token ids are in each lane, and 4 heads are across
        // 4 lanes
#pragma unroll
        for(int h = 0; h < QHLOOP; h++)
        {
            floatx4 tmp = {0};
#pragma unroll
            for(int i = 0; i < 4; i++)
            {
                const float B = (lane4id == i) ? 1.0f : 0.0f;
                // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
                tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
@@ -539,50 +604,58 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        const int lane4_token_idx = 4 * (global_token_idx >> 2);
        const int alibi_offset = lane4_token_idx - context_len + 1;
        if(alibi_slopes != nullptr)
        {
#pragma unroll
            for(int h = 0; h < QHLOOP; h++)
            {
#pragma unroll
                for(int i = 0; i < 4; i++)
                {
                    dout[h][i] += alibi_slope[h] * (alibi_offset + i);
                }
            }
        }
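        // softmax, step 1: per-head max over this warp's tokens, masked to the context
        // length, then reduced across lanes with __shfl_xor down to groups of 4 lanes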
#pragma unroll
        for(int h = 0; h < QHLOOP; h++)
        {
            qk_max[h] = -FLT_MAX;
#pragma unroll
            for(int i = 0; i < 4; i++)
            {
                qk_max[h] =
                    (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h];
            }
#pragma unroll
            for(int mask = WARP_SIZE / 2; mask >= 4; mask /= 2)
            {
                qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
            }
        }
        float exp_sum[QHLOOP];
#pragma unroll
        for(int h = 0; h < QHLOOP; h++)
        {
            exp_sum[h] = 0.0f;
#pragma unroll
            for(int i = 0; i < 4; i++)
            {
                dout[h][i] =
                    (lane4_token_idx + i < context_len) ? __expf(dout[h][i] - qk_max[h]) : 0.0f;
                exp_sum[h] += dout[h][i];
            }
#pragma unroll
            for(int mask = WARP_SIZE / 2; mask >= 4; mask /= 2)
            {
                exp_sum[h] += __shfl_xor(exp_sum[h], mask);
            }
        }
#pragma unroll
        for(int h = 0; h < QHLOOP; h++)
        {
            const int head_idx = 4 * h + lane4id;
            shared_qk_max[warpid][head_idx] = qk_max[h];
            shared_exp_sum[warpid][head_idx] = exp_sum[h];
@@ -592,95 +665,86 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    __syncthreads();
    const int num_heads = gridDim.z * GQA_RATIO;
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
    for(int h = 0; h < QHLOOP; h++)
    {
        float global_qk_max = -FLT_MAX;
        float warp_qk_max[NWARPS];
        const int head_idx = 4 * h + lane4id;
#pragma unroll
        for(int w = 0; w < NWARPS; w++)
        {
            warp_qk_max[w] = shared_qk_max[w][head_idx];
            global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
        }
        float global_exp_sum = 0.0f;
#pragma unroll
        for(int w = 0; w < NWARPS; w++)
        {
            global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
        }
        if(head_idx < GQA_RATIO)
        {
            max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = global_qk_max;
            exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = global_exp_sum;
        }
        const float global_inv_sum_scale =
            __fdividef(1.f, global_exp_sum + 1e-6f) * __expf(qk_max[h] - global_qk_max);
        dout[h] *= global_inv_sum_scale;
    }
    // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
    // are 4x16 tokens across warp
    _B16x4 logits[QHLOOP];
#pragma unroll
    for(int h = 0; h < QHLOOP; h++)
    {
        logits[h] = from_floatx4<scalar_t>(dout[h]);
    }
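    // each warp accumulates its partial softmax(QK^T)*V into vout_shared; warp 0 sums
    // the per-warp partials and writes the result below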
    __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
    if(warp_start_token_idx >= context_len)
    { // warp out of context
#pragma unroll
        for(int qh = 0; qh < QHLOOP; qh++)
        {
#pragma unroll
            for(int vh = 0; vh < VHELOOP; vh++)
            {
                vout_shared[qh][vh][laneid][warpid] = {0};
            }
        }
    }
    else
    { // warp in context
        // iterate across heads
#pragma unroll
        for(int qh = 0; qh < QHLOOP; qh++)
        {
            // iterate over each v head elem (within head_size)
#pragma unroll
            for(int vh = 0; vh < VHELOOP; vh++)
            {
                floatx4 acc = {0};
                // iterate over tokens
                acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh], Vlocal[vh][5].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh], Vlocal[vh][5].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh], Vlocal[vh][6].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh], Vlocal[vh][6].xy[1], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh], Vlocal[vh][7].xy[0], acc);
                acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh], Vlocal[vh][7].xy[1], acc);
                vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
            }
        }
@@ -688,42 +752,48 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    __syncthreads();
    if(warpid == 0)
    {
        // const float out_scale = (fp8_out_scale_ptr != nullptr) ?
        // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
        const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
        _B16x4 vout[QHLOOP][VHELOOP];
        // iterate across heads
#pragma unroll
        for(int qh = 0; qh < QHLOOP; qh++)
        {
            // iterate over each v head elem (within head_size)
#pragma unroll
            for(int vh = 0; vh < VHELOOP; vh++)
            {
                vout[qh][vh] = {0};
#pragma unroll
                for(int w = 0; w < NWARPS; w++)
                {
                    vout[qh][vh] = addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
                }
            }
        }
        if(context_len > partition_size)
        {
            scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                                partition_idx * HEAD_SIZE;
            const int out_num_partitions = max_num_partitions;
            bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
            for(int qh = 0; qh < QHLOOP; qh++)
            {
#pragma unroll
                for(int vh = 0; vh < VHELOOP; vh++)
                {
                    const int head_size_elem = vh * WARP_SIZE + laneid;
#pragma unroll
                    for(int i = 0; i < 4; i++)
                    {
                        const int head_idx = 4 * qh + i;
                        if(head_idx < GQA_RATIO)
                        {
                            out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
                                            HEAD_SIZE +
                                        head_size_elem] = vout[qh][vh][i];
@@ -732,31 +802,42 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
                }
            }
        } // context_len > partition_size
        else
        {
            bit8_t* final_out_ptr_b8;
            bit16_t* final_out_ptr_b16;
            if constexpr(std::is_same<OUTT, bit8_t>::value)
            {
                final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE;
            }
            else
            {
                OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
                final_out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
            }
#pragma unroll
            for(int qh = 0; qh < QHLOOP; qh++)
            {
#pragma unroll
                for(int vh = 0; vh < VHELOOP; vh++)
                {
                    const int head_size_elem = vh * WARP_SIZE + laneid;
#pragma unroll
                    for(int i = 0; i < 4; i++)
                    {
                        const int head_idx = 4 * qh + i;
                        if(head_idx < GQA_RATIO)
                        {
                            if constexpr(std::is_same<OUTT, bit8_t>::value)
                            {
                                const float tmpf =
                                    out_scale * to_float_b16<scalar_t>(vout[qh][vh][i]);
                                const OUTT tmp = hip_fp8(tmpf).data;
                                final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE +
                                                 head_size_elem] = tmp;
                            }
                            else
                            {
                                final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE +
                                                  head_size_elem] = vout[qh][vh][i];
                            }
@@ -769,10 +850,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums, // [num_seqs, num_heads,
                                        // max_num_partitions]
@@ -781,13 +865,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
                                          // max_num_partitions, head_size]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr)
{
    const int num_heads = gridDim.x;
    const int head_idx = blockIdx.x;
    const int seq_idx = blockIdx.y;
    const int context_len = context_lens[seq_idx];
    const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
    if(num_partitions == 1)
    {
        // if num_partitions==1, main kernel will write to out directly, no work in
        // reduction kernel
        return;
@@ -801,10 +888,10 @@
    // max num partitions supported is warp_size * NPAR_LOOPS
    __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];
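    // warp 0 finds the global max logit across partitions and rescales each partition's
    // exp-sum to that max before the final weighted aggregation of tmp_out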
    if(warpid == 0)
    {
        const float* max_logits_ptr =
            max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
        // valid partition is the last valid partition in case threadid > num
        // partitions
@@ -812,70 +899,78 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
        float reg_max_logit[NPAR_LOOPS];
        const int last_valid_partition = num_partitions - 1;
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            valid_partition[i] =
                (partition_no < num_partitions) ? partition_no : last_valid_partition;
        }
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
        }
        float max_logit = reg_max_logit[0];
#pragma unroll
        for(int i = 1; i < NPAR_LOOPS; i++)
        {
            max_logit = fmaxf(max_logit, reg_max_logit[i]);
        }
#pragma unroll
        for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
        {
            max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
        }
        const float* exp_sums_ptr =
            exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
        float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
        }
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            rescaled_exp_sum[i] *=
                (partition_no < num_partitions) ? expf(reg_max_logit[i] - max_logit) : 0.0f;
        }
        float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
        for(int i = 1; i < NPAR_LOOPS; i++)
        {
            global_exp_sum += rescaled_exp_sum[i];
        }
#pragma unroll
        for(int i = 0; i < NPAR_LOOPS; i++)
        {
            const int partition_no = i * WARP_SIZE + threadIdx.x;
            shared_exp_sums[partition_no] = rescaled_exp_sum[i];
        }
#pragma unroll
        for(int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
        {
            global_exp_sum += __shfl_xor(global_exp_sum, mask);
        }
        if(threadIdx.x == 0)
        {
            shared_global_exp_sum = global_exp_sum;
        }
    } // warpid == 0
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                                  head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
    constexpr int MAX_NPAR = 64;
    scalar_t tmps[MAX_NPAR];
    const float dzero = 0.0f;
#pragma unroll
    for(int j = 0; j < MAX_NPAR; j++)
    {
        tmps[j] = from_float<scalar_t>(dzero);
    }
    const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
@@ -884,32 +979,32 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    constexpr int JCHUNK = 16;
#pragma unroll
    for(int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
    {
        // lastj is last valid partition
        const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
    }
    __syncthreads();
    if(num_partitions > JCHUNK)
    {
#pragma unroll
        for(int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; j += HEAD_SIZE)
        {
            const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
            tmps[idx] = tmp_out_ptr[lastj_offset];
            idx++;
        }
        if(num_partitions > 2 * JCHUNK)
        {
#pragma unroll
            for(int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE)
            {
                const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
                tmps[idx] = tmp_out_ptr[lastj_offset];
                idx++;
            }
@@ -918,64 +1013,77 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    // Aggregate tmp_out to out.
    float acc = 0.0f;
#pragma unroll
    for(int j = 0; j < JCHUNK; j++)
    {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if(num_partitions > JCHUNK)
    {
#pragma unroll
        for(int j = JCHUNK; j < 2 * JCHUNK; j++)
        {
            acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
        }
        if(num_partitions > 2 * JCHUNK)
        {
#pragma unroll
            for(int j = 2 * JCHUNK; j < MAX_NPAR; j++)
            {
                acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
            }
        }
    }
    for(int p = 1; p < NPAR_LOOPS; p++)
    {
        if(num_partitions > p * MAX_NPAR)
        {
            idx = 0;
#pragma unroll
            for(int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
                j += HEAD_SIZE)
            {
                // lastj is last valid partition
                const int lastj_offset = (j < num_partition_offset) ? j : last_partition_offset;
                tmps[idx] = tmp_out_ptr[lastj_offset];
                idx++;
            }
#pragma unroll
            for(int j = 0; j < MAX_NPAR; j++)
            {
                acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
            }
        }
    }
    const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
    // const float out_scale = (fp8_out_scale_ptr != nullptr) ?
    // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f;
    const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
    acc *= inv_global_exp_sum;
    acc *= out_scale;
    OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    if constexpr(std::is_same<OUTT, bit8_t>::value)
    {
        out_ptr[threadIdx.x] = hip_fp8(acc).data;
    }
    else
    {
        out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
    }
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t,
          typename cache_t,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          typename OUTT,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@@ -983,28 +1091,37 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
                                    // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
                                         // head_size, block_size]
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,
    const int kv_block_stride,
    const int kv_head_stride,
    float* __restrict__ exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits, // [num_seqs, num_heads,
                                    // max_num_partitions]
    scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
                                // head_size]
    OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
    int max_ctx_blocks,
    float k_scale,
    float v_scale,
    const float* __restrict__ fp8_out_scale_ptr)
{
    UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t,
          typename OUTT,
          int HEAD_SIZE,
          int NUM_THREADS,
          int PARTITION_SIZE,
          int NPAR_LOOPS>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,             // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums, // [num_seqs, num_heads,
                                        // max_num_partitions]
@@ -1014,6 +1131,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                          // max_num_partitions, head_size]
    const int* __restrict__ context_lens, // [num_seqs]
    const int max_num_partitions,
    const float* __restrict__ fp8_out_scale_ptr)
{
    UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support