Unverified commit dd5fa7e0, authored by Hosang and committed by GitHub
Browse files

[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)


Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
parent 2b161045
...@@ -84,7 +84,10 @@ def main( ...@@ -84,7 +84,10 @@ def main(
if version == "v2": if version == "v2":
if current_platform.is_rocm(): if current_platform.is_rocm():
global PARTITION_SIZE global PARTITION_SIZE
PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM if not args.custom_paged_attn and not current_platform.is_navi():
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
...@@ -159,6 +162,7 @@ def main( ...@@ -159,6 +162,7 @@ def main(
scale, scale,
block_tables, block_tables,
seq_lens, seq_lens,
None,
block_size, block_size,
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
......
...@@ -30,6 +30,14 @@ ...@@ -30,6 +30,14 @@
#define __HIP__GFX9__ #define __HIP__GFX9__
#endif #endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
#define __HIP__GFX11__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG) #if defined(NDEBUG)
#undef NDEBUG #undef NDEBUG
#include <assert.h> #include <assert.h>
...@@ -43,7 +51,7 @@ ...@@ -43,7 +51,7 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#if defined(__HIP__GFX9__) // TODO: Add NAVI support #if defined(__HIP__GFX9__)
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
...@@ -1482,198 +1490,1697 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -1482,198 +1490,1697 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
} }
} }
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #elif defined(__HIP__GFX11__)
// clang-format off using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE, using bit16_t = uint16_t;
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
int GQA_RATIO> typedef bit16x4 _B16x4;
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t;
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] union b16x8_u {
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] bit16x8 u16x8;
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] _B16x4 xy[2];
const int num_kv_heads, };
const float scale, typedef b16x8_u _B16x8;
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] using bit16x16 =
const int* __restrict__ query_start_loc_ptr, // [num_seqs] __attribute__((__vector_size__(16 * sizeof(uint16_t)))) uint16_t;
const int max_num_blocks_per_seq, union b16x16_u {
const float* __restrict__ alibi_slopes, // [num_heads] bit16x16 u16x16;
const int q_stride, _B16x8 xy[2];
const int kv_block_stride, };
const int kv_head_stride, typedef b16x16_u _B16x16;
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] using _B8x8 = uint2;
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] using bit8_t = uint8_t;
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) { typedef struct _B8x16 {
UNREACHABLE_CODE _B8x8 xy[2];
} _B8x16;
// Dispatches to the wave32 16x16x16 WMMA builtin that matches the 16-bit
// element type T and returns its result unchanged.
//   T == _Float16        -> __builtin_amdgcn_wmma_f32_16x16x16_f16_w32
//   T == __hip_bfloat16  -> __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
// inpA / inpB are the two 16-element 16-bit operand vectors and inpC is the
// 8-float accumulator passed straight through to the builtin.
// absz, cbid and blgp are not read by this body; the call sites in this file
// instantiate it with <scalar_t, 0, 0, 0>.
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x16& inpA,
                                                          const bit16x16& inpB,
                                                          const floatx8& inpC) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(inpA, inpB, inpC);
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(inpA, inpB, inpC);
  } else {
    // The condition must depend on T: a literal `static_assert(false)` in a
    // template is rejected at definition time by compilers that predate
    // CWG2518, even when this branch is discarded.
    static_assert(sizeof(T) == 0, "unsupported 16b dtype");
  }
}
// Widens one 16-bit floating-point value to float.
// _Float16 is converted with a plain cast; __hip_bfloat16 goes through
// __bfloat162float. Any other T is a compile-time error.
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(inp);
  } else {
    // Type-dependent condition so the assertion only fires when this branch
    // is actually instantiated (a literal `false` is rejected up front by
    // compilers that predate CWG2518).
    static_assert(sizeof(T) == 0, "unsupported 16b dtype");
  }
}
// Narrows one float to the 16-bit floating-point type T.
// _Float16 is produced with a plain cast; __hip_bfloat16 goes through
// __float2bfloat16. Any other T is a compile-time error.
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (_Float16)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __float2bfloat16(inp);
  } else {
    // Type-dependent condition so the assertion only fires when this branch
    // is actually instantiated (a literal `false` is rejected up front by
    // compilers that predate CWG2518).
    static_assert(sizeof(T) == 0, "unsupported 16b dtype");
  }
}
// Packs eight floats into a _B16x8 holding eight 16-bit values of type T.
// The inputs are converted two at a time (elements 0-1, 2-3, 4-5, 6-7) with
// the paired conversion intrinsic for T, and the four packed pairs are
// reinterpreted as a _B16x8 through a local union.
// Any T other than _Float16 / __hip_bfloat16 is a compile-time error.
template <typename T>
__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    union h2cvt {
      __half2 h2[4];
      _B16x8 b16x8;
    } u;
    u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1]));
    u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3]));
    u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5]));
    u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7]));
    return u.b16x8;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    union b2cvt {
      __hip_bfloat162 b2[4];
      _B16x8 b16x8;
    } u;
    u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1]));
    u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3]));
    u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5]));
    u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7]));
    return u.b16x8;
  } else {
    // Type-dependent condition so the assertion only fires when this branch
    // is actually instantiated (a literal `false` is rejected up front by
    // compilers that predate CWG2518).
    static_assert(sizeof(T) == 0, "unsupported 16b dtype");
  }
}
// clang-format off
template <typename scalar_t, typename cache_t, template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE, vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
int GQA_RATIO>
__global__ __global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const int num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) { int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE // clang-format on
} constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const int lane2id = laneid % 2;
const int lane4id = laneid % 4;
const int lane16id = laneid % 16;
const int rowid = laneid / 16;
// Grid: (num_heads, num_seqs). const int seq_idx = blockIdx.x;
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS, // NOTE queries with sequence len > 1 are prefills and taken care by another
int PARTITION_SIZE, int NPAR_LOOPS> // kernel.
__global__ if (query_start_loc_ptr != nullptr &&
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] return;
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] }
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
}
// clang-format on
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support const int partition_idx = blockIdx.y;
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ constexpr int T_PAR_SIZE = 256; // token partition size set to 256
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ const int max_num_partitions = gridDim.y;
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ const int context_len = context_lens[seq_idx]; // length of a seq
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE, const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, // exit if partition is out of context for seq
bool ALIBI_ENABLED> if (partition_start_token_idx >= context_len) {
void paged_attention_custom_launcher( return;
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, }
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: query start location is optional for V0 decode should not be used. constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2);
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
// NOTE: alibi_slopes is optional. __shared__ float shared_qk_max[NWARPS][16 + 1];
const float* alibi_slopes_ptr = __shared__ float shared_exp_sum[NWARPS][16 + 1];
alibi_slopes // shared_logits is used for multiple purposes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) __shared__ _B16x16 shared_logits[NWARPS][2][16][2];
: nullptr;
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); // for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes,
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); // 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); constexpr int ROWS_PER_WARP =
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); WARP_SIZE / 16 / 2; // rows refers to 16 lanes; refer dpp terminology
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr()); constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD =
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr()); 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types
int* block_tables_ptr = block_tables.data_ptr<int>(); constexpr int QKHE_PER_FETCH =
int* context_lens_ptr = context_lens.data_ptr<int>(); CONTIGUOUS_KV_ELEMS_16B_LOAD *
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr()); ROWS_PER_WARP; // each fetch across a warp fetches these many elements
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across
// NOTE: fp8_out_scale is optional. // warp
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); _B16x16 Qlocal[QKHELOOP / 2]; // note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
constexpr int NTHR = 256; constexpr int TOKENS_PER_WARP =
dim3 grid(num_seqs, max_num_partitions, num_kv_heads); T_PAR_SIZE /
dim3 block(NTHR); NWARPS; // sub partition of tokens per warp for qk calculation
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); constexpr int TLOOP =
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TOKENS_PER_WARP /
16; // each wmma16x16x16 instruction processes 16 tokens
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 _B16x16 Klocal[TLOOP]
switch (gqa_ratio) { [QKHELOOP / 2]; // can be interpreted as B8x16 for 8 bit types
case 1:
LAUNCH_CUSTOM_ATTENTION_MFMA4(1); const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
break; const int wg_start_kv_head_idx = blockIdx.z;
case 2: const int total_num_heads = gridDim.z * GQA_RATIO;
LAUNCH_CUSTOM_ATTENTION_MFMA4(2);
break; // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
case 3: // each wmma takes QH16xT16x16HE across warp
LAUNCH_CUSTOM_ATTENTION_MFMA4(3); // repeat wmma across QKHELOOP dimension
break; // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
case 4: // across 2 rows x 8 tokens per lane
LAUNCH_CUSTOM_ATTENTION_MFMA4(4);
break; const int64_t query_start_off = static_cast<int64_t>(
case 5: query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
break; if (GQA_RATIO == 1) {
case 6: const int local_qhead_idx = lane16id % GQA_RATIO;
LAUNCH_CUSTOM_ATTENTION_MFMA16(6); const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
break; const scalar_t* q_ptr =
case 7: q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
if (lane16id < GQA_RATIO) {
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH * 2;
const _B16x16* q_fetch_ptr_32B =
reinterpret_cast<const _B16x16*>(q_fetch_ptr);
Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
}
}
} else {
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 2 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr =
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
const scalar_t* q_fetch_ptr = q_ptr + qhead_element;
const _B16x8* q_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
_B16x8 tmp = *q_fetch_ptr_16B;
const int offset1 =
lane16id /
2; // 16 contiguous chunks of head elems are spread across 8x2lanes
shared_logits[offset1][lane2id][local_qhead_idx][0].xy[0] = tmp;
}
__syncthreads();
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
Qlocal[qkhe_depth].xy[0] =
shared_logits[qkhe_depth][0][lane16id % GQA_RATIO][0].xy[0];
Qlocal[qkhe_depth].xy[1] =
shared_logits[qkhe_depth][1][lane16id % GQA_RATIO][0].xy[0];
}
}
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
int kphysical_block_number[TLOOP];
// fetch k physical block numbers
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len)
? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
constexpr int KX = 16 / sizeof(cache_t);
const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
const int row_head_elem = 0;
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int64_t kblock_number =
static_cast<int64_t>(kphysical_block_number[token_depth]);
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH;
const int offset1 = head_elem / KX;
const int offset2 = head_elem % KX;
const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2;
const _B16x8* k_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(k_fetch_ptr);
Klocal[token_depth][qkhe_depth / 2].xy[qkhe_depth % 2] = *k_fetch_ptr_16B;
}
}
constexpr int VTOKENS_PER_LANE =
TOKENS_PER_WARP / ROWS_PER_WARP; // 32/1 = 32 vtokens per lane
constexpr int VBLOCKS_PER_LANE = 2; // assumes block size >=16
constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps
constexpr int VTLANELOOP = DIVIDE_ROUND_UP(
VTOKENS_PER_LANE,
CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes
// minimum block size is 16
constexpr int VHELOOP = DIVIDE_ROUND_UP(
(HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each
// wmma instr works on 16 head elements
int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE];
// fetch v physical block numbers
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE;
vblock_depth++) {
const int vlocal_token_idx =
vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP +
vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len)
? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
}
_B16x16 Vlocal[VTLOOP][VHELOOP]
[VTLANELOOP / 2]; // this can be interpreted as B8x16 too
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
// v fetches are 16head elems across lanes x (16x2) tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id;
const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE;
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int64_t vblock_number = static_cast<int64_t>(
vphysical_block_number[vtoken_depth]
[vfetch_depth / VBLOCKS_PER_LANE]);
const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride);
const cache_t* v_fetch_ptr =
v_ptr3 +
(vfetch_depth % VBLOCKS_PER_LANE) * CONTIGUOUS_KV_ELEMS_16B_LOAD;
const _B16x8* v_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(v_fetch_ptr);
Vlocal[vtoken_depth][vhe_depth][vfetch_depth / 2].xy[vfetch_depth % 2] =
*v_fetch_ptr_16B;
}
}
}
floatx8 dout[TLOOP];
// qk wmma
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] = {0};
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocal[token_depth][qkhe_depth].u16x16, Qlocal[qkhe_depth].u16x16,
dout[token_depth]);
}
dout[token_depth] *= scale;
}
// calculate qk_max and exp_sum per warp and write to shared memory
float qk_max = -FLT_MAX;
float exp_sum = 0.0f;
const int qkout_token_idx =
partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid;
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len)
? dout[token_depth][i]
: -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16));
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < context_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
exp_sum += tmp;
}
}
exp_sum += __shfl_xor(exp_sum, 16);
__syncthreads();
if (laneid < 16) {
shared_qk_max[warpid][lane16id] = qk_max;
shared_exp_sum[warpid][lane16id] = exp_sum;
}
__syncthreads();
// calculate partition qk_max and exp_sum
float partition_qk_max = -FLT_MAX;
float warp_qk_max_exp[NWARPS];
float partition_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = shared_qk_max[w][lane16id];
partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]);
}
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max);
partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w];
}
const float inv_sum_scale =
__fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];
__syncthreads();
// write logits to shared mem
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] *= inv_sum_scale;
shared_logits[warpid][token_depth][lane16id][0].xy[rowid] =
from_floatx8<scalar_t>(dout[token_depth]);
}
__syncthreads();
_B16x8 swp_buf[TLOOP][2];
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
swp_buf[token_depth][0] =
shared_logits[warpid][token_depth][lane16id][0].xy[0];
swp_buf[token_depth][1] =
shared_logits[warpid][token_depth][lane16id][0].xy[1];
}
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
shared_logits[warpid][token_depth][lane16id][0].xy[rowid].u16x8[i] =
swp_buf[token_depth][i % 2].u16x8[4 * rowid + (i / 2)];
}
}
// write out partition max_logits and exp_sum
if (threadIdx.x < GQA_RATIO) {
const int qhead_idx = lane16id;
const int offset = seq_idx * total_num_heads * max_num_partitions +
(wg_start_head_idx + qhead_idx) * max_num_partitions +
partition_idx;
max_logits[offset] = partition_qk_max;
exp_sums[offset] = partition_exp_sum;
}
__syncthreads();
_B16x8 outelems[VHELOOP];
// Softmax V wmma
// v layout: 16he across lanes x (16x2) tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
floatx8 tmp_out = {0};
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP / 2;
vfetch_depth++) {
const int offset = vfetch_depth;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x16,
shared_logits[vtoken_depth][offset][lane16id][0].u16x16, tmp_out);
}
}
outelems[vhe_depth] = from_floatx8<scalar_t>(tmp_out);
}
__syncthreads();
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid] =
outelems[vhe_depth]; // lane16 id head dimension; rowid head element
// dimension
}
__syncthreads();
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
swp_buf[vhe_depth][0] = shared_logits[warpid][vhe_depth][lane16id][0].xy[0];
swp_buf[vhe_depth][1] = shared_logits[warpid][vhe_depth][lane16id][0].xy[1];
}
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid].u16x8[i] =
swp_buf[vhe_depth][i % 2].u16x8[4 * rowid + (i / 2)];
}
}
__syncthreads();
// write to tmp_out with coalesced writes after reading from shared mem
if (warpid == 0) {
_B16x8 vout[GQA_RATIO2];
// each lane writes out 16Bytes of tmp_out along head elem dimension
const int head_elem_idx = lane16id * 8;
if (head_elem_idx < HEAD_SIZE) {
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
const int offset1 = (head_elem_idx / 16) % NWARPS;
const int offset2 = head_elem_idx / 16 / NWARPS;
const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row
vout[h] =
shared_logits[offset1][offset2][local_head_idx][0].xy[offset3];
}
const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult +
partition_idx * HEAD_SIZE;
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
if (local_head_idx < GQA_RATIO) {
const int out_head_idx = wg_start_head_idx + local_head_idx;
scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
_B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
*out_ptr_B16x8 = vout[h];
}
}
}
}
}
// mfma4 variant of the paged-attention QKV kernel. In this view it appears to
// sit inside the __HIP__GFX11__ preprocessor branch, where it is only a stub:
// the body consists solely of UNREACHABLE_CODE, so none of the parameters are
// read here.
// NOTE(review): UNREACHABLE_CODE is defined outside this view; presumably it
// traps/asserts if the kernel is ever executed -- confirm.
// NOTE(review): the parameter list is the same as
// paged_attention_ll4mi_QKV_mfma16_kernel, presumably so shared launch code
// can name either kernel with identical arguments -- confirm against the
// launcher.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
// Stub body: no work is performed on this target.
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
// max num partitions supported is warp_size * NPAR_LOOPS
__shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];
if (warpid == 0) {
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
// valid partition is the last valid partition in case threadid > num
// partitions
int valid_partition[NPAR_LOOPS];
float reg_max_logit[NPAR_LOOPS];
const int last_valid_partition = num_partitions - 1;
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
valid_partition[i] =
(partition_no < num_partitions) ? partition_no : last_valid_partition;
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
}
float max_logit = reg_max_logit[0];
#pragma unroll
for (int i = 1; i < NPAR_LOOPS; i++) {
max_logit = fmaxf(max_logit, reg_max_logit[i]);
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
}
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
rescaled_exp_sum[i] *= (partition_no < num_partitions)
? expf(reg_max_logit[i] - max_logit)
: 0.0f;
}
float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
for (int i = 1; i < NPAR_LOOPS; i++) {
global_exp_sum += rescaled_exp_sum[i];
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
shared_exp_sums[partition_no] = rescaled_exp_sum[i];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += __shfl_xor(global_exp_sum, mask);
}
if (threadIdx.x == 0) {
shared_global_exp_sum = global_exp_sum;
}
} // warpid == 0
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
constexpr int MAX_NPAR = 32;
scalar_t tmps[MAX_NPAR];
const float dzero = 0.0f;
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
tmps[j] = from_float<scalar_t>(dzero);
}
const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
const int num_partition_offset = (num_partitions)*HEAD_SIZE;
int idx = 0;
constexpr int JCHUNK = 16;
#pragma unroll
for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
__syncthreads();
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
}
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for (int j = 0; j < JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
}
}
for (int p = 1; p < NPAR_LOOPS; p++) {
if (num_partitions > p * MAX_NPAR) {
idx = 0;
#pragma unroll
for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
}
}
}
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#elif defined(__HIP__GFX12__)
// Vector and packed-register types used by the gfx12 WMMA code path.

// 8 x f32 vector; the accumulator/result type of gcn_wmma16x16x16_instr.
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
// Raw 16-bit storage for one half / bfloat16 value (see to_float_b16).
using bit16_t = uint16_t;
// 4 packed 16-bit lanes (8 bytes).
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
using _B16x4 = bit16x4;
// 8 packed 16-bit lanes (16 bytes); the A/B operand type of the WMMA wrapper.
using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t;
// One 16-byte register, viewed either as a single 8-lane vector or as two
// 4-lane halves.
union b16x8_u {
  bit16x8 u16x8;
  _B16x4 xy[2];
};
using _B16x8 = b16x8_u;
// 8 packed 8-bit values, stored as a uint2.
using _B8x8 = uint2;
using bit8_t = uint8_t;
// 16 packed 8-bit values, stored as two _B8x8 halves.
struct _B8x16 {
  _B8x8 xy[2];
};
// Issues one gfx12 wave32 16x16x16 WMMA with f32 accumulation, selecting the
// f16 or bf16 builtin from the 16-bit element type T. Returns the updated
// accumulator (inpC plus the product contributed by inpA and inpB).
// The absz / cbid / blgp template arguments are accepted but not used by this
// implementation. Any other T is rejected at compile time.
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA,
                                                          const bit16x8& inpB,
                                                          const floatx8& inpC) {
  constexpr bool kIsFp16 = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsFp16) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(inpA, inpB, inpC);
  } else if constexpr (kIsBf16) {
    return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(inpA, inpB, inpC);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Widens a 16-bit floating value (T = _Float16 or __hip_bfloat16) to float.
// Any other T is rejected at compile time.
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  constexpr bool kIsFp16 = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsFp16) {
    return static_cast<float>(inp);
  } else if constexpr (kIsBf16) {
    return __bfloat162float(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Reinterprets a raw 16-bit pattern as T (_Float16 or __hip_bfloat16) and
// widens it to float. The union gives the three views of the same 16 bits.
// Any other T is rejected at compile time.
template <typename T>
__device__ __forceinline__ float to_float_b16(const bit16_t& inp) {
  union bits16_view {
    bit16_t raw;
    _Float16 as_fp16;
    __hip_bfloat16 as_bf16;
  } view;
  view.raw = inp;
  constexpr bool kIsFp16 = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsFp16) {
    return static_cast<float>(view.as_fp16);
  } else if constexpr (kIsBf16) {
    return __bfloat162float(view.as_bf16);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Narrows a float to the 16-bit floating type T (_Float16 or __hip_bfloat16).
// Any other T is rejected at compile time.
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  constexpr bool kIsFp16 = std::is_same<T, _Float16>::value;
  constexpr bool kIsBf16 = std::is_same<T, __hip_bfloat16>::value;
  if constexpr (kIsFp16) {
    return static_cast<_Float16>(inp);
  } else if constexpr (kIsBf16) {
    return __float2bfloat16(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// Converts an 8-lane float vector into 8 packed 16-bit values of type T
// (_Float16 or __hip_bfloat16). Lanes are converted two at a time with the
// round-to-nearest pair intrinsics, so lanes (2p, 2p+1) land in pair p of the
// result. Any other T is rejected at compile time.
template <typename T>
__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
  constexpr int kPairs = 4;  // 8 lanes converted as 4 pairs
  if constexpr (std::is_same<T, _Float16>::value) {
    union half_pack {
      __half2 pairs[4];
      _B16x8 packed;
    } cvt;
#pragma unroll
    for (int p = 0; p < kPairs; ++p) {
      cvt.pairs[p] =
          __float22half2_rn(make_float2(inp[2 * p], inp[2 * p + 1]));
    }
    return cvt.packed;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    union bf16_pack {
      __hip_bfloat162 pairs[4];
      _B16x8 packed;
    } cvt;
#pragma unroll
    for (int p = 0; p < kPairs; ++p) {
      cvt.pairs[p] =
          __float22bfloat162_rn(make_float2(inp[2 * p], inp[2 * p + 1]));
    }
    return cvt.packed;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
// clang-format off
// gfx12 build (selected by __HIP__GFX12__) of the partitioned paged-attention
// decode kernel, implemented with 16x16x16 WMMA instructions.
//
// Grid: blockIdx.x = sequence, blockIdx.y = token partition (T_PAR_SIZE = 256
// tokens each), blockIdx.z = KV head. Each block handles the GQA_RATIO query
// heads that share that KV head.
//
// For its partition the block:
//   1. loads Q (one decode token per sequence) and the partition's K and V
//      from the paged caches via block_tables,
//   2. computes scaled QK^T, then a softmax that is normalized within the
//      partition,
//   3. writes the partition's max logit and exp-sum to max_logits / exp_sums,
//   4. multiplies the normalized logits by V and writes the partial result to
//      `out` ([num_seqs, num_heads, max_num_partitions, head_size]).
// A separate reduce kernel combines the per-partition results.
//
// NOTE(review): num_kv_heads, alibi_slopes, final_out, max_ctx_blocks, k_scale
// and v_scale, as well as the KV_DTYPE and ALIBI_ENABLED template arguments,
// are not referenced in this body. K/V are loaded as raw 16-byte chunks and
// fed to the 16-bit WMMA directly, so this build presumably expects a 16-bit
// KV cache without ALiBi -- confirm against the launcher.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
// clang-format on
// presumably 8 warps on gfx12 (NUM_THREADS = 256, WARP_SIZE = 32) -- TODO
// confirm against the launch configuration
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const int lane2id = laneid % 2;
// NOTE(review): lane4id is not used anywhere below.
const int lane4id = laneid % 4;
const int lane16id = laneid % 16;
// rowid selects which group of 16 lanes ("row") of the warp this lane is in
const int rowid = laneid / 16;
const int seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
// The condition depends only on blockIdx.x, so the whole block returns
// together (no barrier is skipped by a subset of threads).
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const int partition_idx = blockIdx.y;
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
const int max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) {
return;
}
constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2);
// per-warp softmax statistics, indexed [warp][query head]; the inner
// dimension is 16 + 1 (presumably padding against LDS bank conflicts)
__shared__ float shared_qk_max[NWARPS][16 + 1];
__shared__ float shared_exp_sum[NWARPS][16 + 1];
// shared_logits is used for multiple purposes
// (in order: staging Q across warps, holding the normalized softmax logits
// for the softmax*V WMMA, and staging the output tiles before the final store)
__shared__ _B16x8 shared_logits[NWARPS][2][16][2];
// for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes,
// 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of
// warp
constexpr int ROWS_PER_WARP =
WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology
constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD =
16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types
constexpr int QKHE_PER_FETCH =
CONTIGUOUS_KV_ELEMS_16B_LOAD *
ROWS_PER_WARP; // each fetch across a warp fetches these many elements
constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across
// warp
_B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);
constexpr int TOKENS_PER_WARP =
T_PAR_SIZE /
NWARPS; // sub partition of tokens per warp for qk calculation
constexpr int TLOOP =
TOKENS_PER_WARP /
16; // each wmma16x16x16 instruction processes 16 tokens
_B16x8 Klocal[TLOOP]
[QKHELOOP]; // can be interpreted as B8x16 for 8 bit types
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const int total_num_heads = gridDim.z * GQA_RATIO;
// for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each wmma takes QH16xT16x16HE across warp
// repeat wmma across QKHELOOP dimension
// output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
// across 2 rows x 8 tokens per lane
// Row of q to read: the query start offset when query_start_loc_ptr is
// given, otherwise the sequence index itself.
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
if (GQA_RATIO == 1) {
// Single query head: lanes with lane16id == 0 load Q straight from global
// memory. Qlocal stays unwritten on the other lanes; only results for
// lane16id < GQA_RATIO are stored at the end.
const int local_qhead_idx = lane16id % GQA_RATIO;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr = q + query_start_off * q_stride +
global_qhead_idx * HEAD_SIZE +
rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
if (lane16id < GQA_RATIO) {
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH;
const _B16x8* q_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
}
}
} else {
// fetch Q in shared across warps and then write to registers
// each (warp, row) pair is responsible for one query head of the group
const int local_qhead_idx = 2 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr =
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
const scalar_t* q_fetch_ptr = q_ptr + qhead_element;
const _B16x8* q_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
_B16x8 tmp = *q_fetch_ptr_16B;
const int offset1 =
lane16id /
2; // 16 contiguous chunks of head elems are spread across 8x2lanes
shared_logits[offset1][lane2id][local_qhead_idx][0] = tmp;
}
// make the staged Q visible to every warp before reading it back
__syncthreads();
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
// lane16id % GQA_RATIO keeps the read inside the heads that were written
Qlocal[qkhe_depth] =
shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0];
}
}
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
int kphysical_block_number[TLOOP];
// fetch k physical block numbers
// tokens at or beyond context_len are clamped to the last context block so
// the block-table lookup stays within the sequence's valid entries
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < context_len)
? kglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
constexpr int KX = 16 / sizeof(cache_t);
const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
// load K: one 16-byte chunk per (token_depth, qkhe_depth) per lane
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int64_t kblock_number =
static_cast<int64_t>(kphysical_block_number[token_depth]);
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
// NOTE(review): kglobal_token_idx is computed but not used in this loop.
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
// The in-block offset uses the partition-local index; this equals the
// global index modulo BLOCK_SIZE only if BLOCK_SIZE divides T_PAR_SIZE.
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH;
const int offset1 = head_elem / KX;
const int offset2 = head_elem % KX;
const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2;
const _B16x8* k_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(k_fetch_ptr);
Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B;
}
}
constexpr int VTOKENS_PER_LANE =
TOKENS_PER_WARP / ROWS_PER_WARP; // 32/2 = 16 vtokens per lane
constexpr int VBLOCKS_PER_LANE = 1; // assumes block size >=16
constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps
constexpr int VTLANELOOP = DIVIDE_ROUND_UP(
VTOKENS_PER_LANE,
CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes
// minimum block size is 16
constexpr int VHELOOP = DIVIDE_ROUND_UP(
(HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each
// wmma instr works on 16 head elements
int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE];
// fetch v physical block numbers
// (same clamping to the last context block as for K)
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE;
vblock_depth++) {
const int vlocal_token_idx =
vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP +
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < context_len)
? vglobal_token_idx / BLOCK_SIZE
: last_ctx_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
}
_B16x8 Vlocal[VTLOOP][VHELOOP]
[VTLANELOOP]; // this can be interpreted as B8x16 too
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride +
((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE);
// v fetches are 16head elems across lanes x 16 tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id;
const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE;
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int vblock_depth = 0;
const int64_t vblock_number = static_cast<int64_t>(
vphysical_block_number[vtoken_depth][vblock_depth]);
const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride);
const cache_t* v_fetch_ptr =
v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD;
const _B16x8* v_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(v_fetch_ptr);
Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B;
}
}
}
floatx8 dout[TLOOP];
// qk wmma
// K is passed as the first operand and Q as the second; the accumulator is
// chained over the head dimension and then multiplied by `scale`
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] = {0};
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8,
dout[token_depth]);
}
dout[token_depth] *= scale;
}
// calculate qk_max and exp_sum per warp and write to shared memory
float qk_max = -FLT_MAX;
float exp_sum = 0.0f;
// first global token index covered by this lane's 8 accumulator elements
const int qkout_token_idx =
partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 8;
// tokens at or beyond context_len are masked out of the max (and get a
// weight of 0 in the exp pass below)
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp =
(local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
// combine with the lane 16 apart (the other row of the warp)
qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16));
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + i < context_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
exp_sum += tmp;
}
}
exp_sum += __shfl_xor(exp_sum, 16);
__syncthreads();
// after the xor-16 exchange both rows hold the same values, so only the
// first 16 lanes publish them
if (laneid < 16) {
shared_qk_max[warpid][lane16id] = qk_max;
shared_exp_sum[warpid][lane16id] = exp_sum;
}
__syncthreads();
// calculate partition qk_max and exp_sum
float partition_qk_max = -FLT_MAX;
float warp_qk_max_exp[NWARPS];
float partition_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = shared_qk_max[w][lane16id];
partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]);
}
// rescale each warp's exp-sum from its own max to the partition-wide max
for (int w = 0; w < NWARPS; w++) {
warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max);
partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w];
}
// factor that turns this warp's exp values into partition-normalized
// softmax weights; the 1e-6f keeps the division finite for an all-zero sum
const float inv_sum_scale =
__fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];
__syncthreads();
// write logits to shared mem
#pragma unroll
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] *= inv_sum_scale;
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx8<scalar_t>(dout[token_depth]);
}
// write out partition max_logits and exp_sum
// one thread per query head of the group (threadIdx.x < GQA_RATIO)
if (threadIdx.x < GQA_RATIO) {
const int qhead_idx = lane16id;
const int offset = seq_idx * total_num_heads * max_num_partitions +
(wg_start_head_idx + qhead_idx) * max_num_partitions +
partition_idx;
max_logits[offset] = partition_qk_max;
exp_sums[offset] = partition_exp_sum;
}
__syncthreads();
_B16x8 outelems[VHELOOP];
// Softmax V wmma
// v layout: 16he across lanes x 16 tokens per lane
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
floatx8 tmp_out = {0};
for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int offset = rowid * VTLANELOOP + vfetch_depth;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8,
shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8,
tmp_out);
}
}
outelems[vhe_depth] = from_floatx8<scalar_t>(tmp_out);
}
// all warps must finish reading the logits before shared_logits is reused
// for the output tiles
__syncthreads();
#pragma unroll
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
shared_logits[warpid][vhe_depth][lane16id][rowid] =
outelems[vhe_depth]; // lane16 id head dimension; rowid head element
// dimension
}
__syncthreads();
// write to tmp_out with coalesced writes after reading from shared mem
if (warpid == 0) {
_B16x8 vout[GQA_RATIO2];
// each lane writes out 16Bytes of tmp_out along head elem dimension
const int head_elem_idx = lane16id * 8;
if (head_elem_idx < HEAD_SIZE) {
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
const int offset1 = (head_elem_idx / 16) % NWARPS;
const int offset2 = head_elem_idx / 16 / NWARPS;
const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row
vout[h] = shared_logits[offset1][offset2][local_head_idx][offset3];
}
const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult +
partition_idx * HEAD_SIZE;
for (int h = 0; h < GQA_RATIO2; h++) {
const int local_head_idx = 2 * h + rowid;
// GQA_RATIO2 rounds up, so skip the extra head when GQA_RATIO is odd
if (local_head_idx < GQA_RATIO) {
const int out_head_idx = wg_start_head_idx + local_head_idx;
scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
_B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
*out_ptr_B16x8 = vout[h];
}
}
}
}
}
// gfx12 build of the mfma4 attention kernel: declaration only. The signature
// mirrors the mfma16 kernel so that LAUNCH_CUSTOM_ATTENTION_MFMA4 can name
// this template, but the body is just UNREACHABLE_CODE, i.e. this variant has
// no gfx12 implementation.
// NOTE(review): UNREACHABLE_CODE is defined outside this excerpt; presumably
// it traps/asserts if the kernel is ever executed -- confirm.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
// Block: NUM_THREADS threads; thread t produces head element t of the output
// row (LAUNCH_CUSTOM_REDUCTION instantiates this with NUM_THREADS ==
// HEAD_SIZE).
//
// gfx12 build of the reduction pass of partitioned paged attention. For one
// (sequence, head) pair it merges the per-partition results written by the
// attention kernel:
//   - warp 0 finds the global max logit over all partitions, rescales every
//     partition's exp-sum to that max and publishes the rescaled sums and
//     their total through shared memory;
//   - every thread then accumulates tmp_out[partition][t] weighted by the
//     rescaled exp-sum, divides by the total and stores the result in `out`.
//
// Parameters:
//   out                 [num_seqs, num_heads, head_size] final output.
//   exp_sums            [num_seqs, num_heads, max_num_partitions].
//   max_logits          [num_seqs, num_heads, max_num_partitions].
//   tmp_out             [num_seqs, num_heads, max_num_partitions, head_size].
//   context_lens        [num_seqs] context length per sequence.
//   query_start_loc_ptr optional; when non-null, sequences whose query length
//                       is not 1 are skipped and the output row is taken from
//                       query_start_loc_ptr[seq_idx] instead of seq_idx.
//   max_num_partitions  partition stride of the three intermediate buffers.
//   fp8_out_scale_ptr   NOTE(review): not read in this build.
//
// At most NPAR_LOOPS * WARP_SIZE partitions are supported (size of
// shared_exp_sums).
//
// A sequence with context_len == 0 has no partitions and the attention kernel
// writes nothing for it. Previously the "clamp to last valid partition" logic
// then produced index -1 and read max_logits / exp_sums / tmp_out out of
// bounds. Such sequences now get a zero output row without touching the
// intermediate buffers.
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    OUTT* __restrict__ out,                // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
  const auto num_heads = gridDim.x;
  const auto head_idx = blockIdx.x;
  const auto seq_idx = blockIdx.y;

  // NOTE queries with sequence len > 1 are prefills and taken care by another
  // kernel.
  // The condition depends only on blockIdx.y, so the whole block returns
  // together and no thread is left waiting at the barrier below.
  if (query_start_loc_ptr != nullptr &&
      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
    return;
  }

  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);

  // Destination row: the query start offset when query_start_loc_ptr is
  // given, otherwise the sequence index itself.
  const int64_t query_start_off = static_cast<int64_t>(
      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
                  static_cast<int64_t>(head_idx) * HEAD_SIZE;

  // Empty context: there is no valid partition to clamp to, so every load
  // below would use a negative index. Emit the zero row the weighted sum would
  // produce and leave. Block-uniform (depends only on blockIdx.y), so it is
  // safe to return before __syncthreads().
  if (num_partitions <= 0) {
    out_ptr[threadIdx.x] = from_float<scalar_t>(0.0f);
    return;
  }

  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS
  __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE];

  if (warpid == 0) {
    const float* max_logits_ptr = max_logits +
                                  seq_idx * num_heads * max_num_partitions +
                                  head_idx * max_num_partitions;

    // valid partition is the last valid partition in case threadid > num
    // partitions
    int valid_partition[NPAR_LOOPS];
    float reg_max_logit[NPAR_LOOPS];
    const int last_valid_partition = num_partitions - 1;

#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      valid_partition[i] =
          (partition_no < num_partitions) ? partition_no : last_valid_partition;
    }
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      reg_max_logit[i] = max_logits_ptr[valid_partition[i]];
    }
    float max_logit = reg_max_logit[0];
#pragma unroll
    for (int i = 1; i < NPAR_LOOPS; i++) {
      max_logit = fmaxf(max_logit, reg_max_logit[i]);
    }

    // butterfly reduction: every lane of warp 0 ends up with the global max
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
    }

    const float* exp_sums_ptr = exp_sums +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;

    float rescaled_exp_sum[NPAR_LOOPS];
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]];
    }
    // rescale to the global max; clamped (out-of-range) slots contribute 0
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      rescaled_exp_sum[i] *= (partition_no < num_partitions)
                                 ? expf(reg_max_logit[i] - max_logit)
                                 : 0.0f;
    }
    float global_exp_sum = rescaled_exp_sum[0];
#pragma unroll
    for (int i = 1; i < NPAR_LOOPS; i++) {
      global_exp_sum += rescaled_exp_sum[i];
    }
    // publish the per-partition weights for the other warps
#pragma unroll
    for (int i = 0; i < NPAR_LOOPS; i++) {
      const int partition_no = i * WARP_SIZE + threadIdx.x;
      shared_exp_sums[partition_no] = rescaled_exp_sum[i];
    }

#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_exp_sum += __shfl_xor(global_exp_sum, mask);
    }
    if (threadIdx.x == 0) {
      shared_global_exp_sum = global_exp_sum;
    }
  }  // warpid == 0

  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
  constexpr int MAX_NPAR = 32;
  scalar_t tmps[MAX_NPAR];
  const float dzero = 0.0f;
#pragma unroll
  for (int j = 0; j < MAX_NPAR; j++) {
    tmps[j] = from_float<scalar_t>(dzero);
  }
  const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
  const int num_partition_offset = (num_partitions)*HEAD_SIZE;
  int idx = 0;

  // The first JCHUNK partitions are loaded before the barrier, so these
  // global loads can be issued while warp 0 is still computing the weights.
  constexpr int JCHUNK = 16;

#pragma unroll
  for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
    // lastj is last valid partition
    const int lastj_offset =
        (j < num_partition_offset) ? j : last_partition_offset;
    tmps[idx] = tmp_out_ptr[lastj_offset];
    idx++;
  }
  // shared_exp_sums / shared_global_exp_sum are valid past this point
  __syncthreads();

  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
         j += HEAD_SIZE) {
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }

    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }
    }
  }  // num_partitions > JCHUNK

  // Aggregate tmp_out to out.
  float acc = 0.0f;
#pragma unroll
  for (int j = 0; j < JCHUNK; j++) {
    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
  }
  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
      }
    }
  }

  // Remaining partitions, MAX_NPAR at a time, reusing the tmps registers.
  for (int p = 1; p < NPAR_LOOPS; p++) {
    if (num_partitions > p * MAX_NPAR) {
      idx = 0;
#pragma unroll
      for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        // lastj is last valid partition
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }

#pragma unroll
      for (int j = 0; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR];
      }
    }
  }

  // the 1e-6f keeps the division finite when the total exp-sum is zero
  const float inv_global_exp_sum =
      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
  acc *= inv_global_exp_sum;

  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#else
// clang-format off
// Fallback declaration of the mfma16 attention kernel for targets that match
// none of the architecture branches above (this is the #else branch). It keeps
// the same signature so LAUNCH_CUSTOM_ATTENTION_MFMA16 still names a valid
// template, but the body is only UNREACHABLE_CODE.
// NOTE(review): UNREACHABLE_CODE is defined outside this excerpt; presumably
// it traps/asserts if the kernel is ever executed -- confirm.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Fallback declaration of the mfma4 attention kernel for targets that match
// none of the architecture branches above (this is the #else branch). It keeps
// the same signature so LAUNCH_CUSTOM_ATTENTION_MFMA4 still names a valid
// template, but the body is only UNREACHABLE_CODE.
// NOTE(review): UNREACHABLE_CODE is defined outside this excerpt; presumably
// it traps/asserts if the kernel is ever executed -- confirm.
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
// Fallback declaration of the reduce kernel for targets that match none of
// the architecture branches above (this is the #else branch). It keeps the
// same signature so LAUNCH_CUSTOM_REDUCTION still names a valid template, but
// the body is only UNREACHABLE_CODE.
// NOTE(review): UNREACHABLE_CODE is defined outside this excerpt; presumably
// it traps/asserts if the kernel is ever executed -- confirm.
template <typename scalar_t, typename OUTT, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE, int NPAR_LOOPS>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
OUTT* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
}
// clang-format on
#endif
// Launches paged_attention_ll4mi_QKV_mfma16_kernel instantiated for the given
// GQA_RATIO. Every other name in the expansion (T, KVT, KV_DTYPE, OUTT,
// BLOCK_SIZE, HEAD_SIZE, NTHR, ALIBI_ENABLED, grid, block, stream and the
// pointer / stride arguments) must already be in scope at the call site.
// The launch uses 0 bytes of dynamic shared memory. The expansion already
// ends with ';'.
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
// Launches paged_attention_ll4mi_QKV_mfma4_kernel instantiated for the
// compile-time GQA_RATIO. It takes exactly the same argument list as
// LAUNCH_CUSTOM_ATTENTION_MFMA16; only the kernel differs.
//
// This macro is not hygienic: the expanding scope must provide T, KVT,
// KV_DTYPE, OUTT, BLOCK_SIZE, HEAD_SIZE, ALIBI_ENABLED, NTHR, `grid`, `block`,
// `stream` and every local named in the argument list below.
// The expansion already ends with ';'.
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
// Launches paged_attention_ll4mi_reduce_kernel instantiated for the
// compile-time NPAR_LOOPS. HEAD_SIZE is passed twice on purpose: once as the
// kernel's HEAD_SIZE and once as its NUM_THREADS template argument.
//
// This macro is not hygienic: the expanding scope must provide T, OUTT,
// HEAD_SIZE, PARTITION_SIZE, `reduce_grid`, `reduce_block`, `stream` and every
// local named in the argument list below.
// The expansion already ends with ';'.
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: query_start_loc is optional; for V0 decode it should not be used.
// If the batch contains a mix of prefills and decodes, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
// Partition size is fixed at 256 because both the mfma4 and mfma16 kernels
// support it; the mfma4 kernel additionally supports a partition size of 512.
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
constexpr int NTHR = 256;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION_MFMA4(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION_MFMA4(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION_MFMA4(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION_MFMA4(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION_MFMA16(6);
break;
case 7:
LAUNCH_CUSTOM_ATTENTION_MFMA16(7); LAUNCH_CUSTOM_ATTENTION_MFMA16(7);
break; break;
case 8: case 8:
...@@ -1744,13 +3251,195 @@ void paged_attention_custom_launcher( ...@@ -1744,13 +3251,195 @@ void paged_attention_custom_launcher(
} }
} }
// Host-side launcher for the custom paged-attention kernels on AMD Radeon
// (Navi) GPUs.
//
// Two kernels are enqueued on the current stream of `query`'s device:
//   1. paged_attention_ll4mi_QKV_mfma16_kernel on a
//      (num_seqs, max_num_partitions, num_kv_heads) grid of NTHR = 256 threads,
//      instantiated for the runtime GQA ratio num_heads / num_kv_heads (1..16).
//   2. paged_attention_ll4mi_reduce_kernel on a (num_heads, num_seqs) grid of
//      head_size threads, instantiated for
//      npar_loops = ceil(max_num_partitions / 32) (1..16). It takes exp_sums,
//      max_logits and tmp_out as read-only inputs and `out` as its output.
//
// Differences from paged_attention_custom_launcher:
//   * every supported GQA ratio goes through the mfma16 kernel (no mfma4 path);
//   * ALiBi slopes are not supported, so the kernel always receives a null
//     alibi_slopes pointer; a caller that supplies slopes is rejected;
//   * there is no fp8 output scale; the reduce kernel receives nullptr;
//   * the reduction is sized for a warp width of 32.
//
// Template parameters:
//   T / KVT / OUTT      element types of query+tmp_out / KV cache / out.
//   KV_DTYPE            KV-cache data-type tag forwarded to the kernel.
//   BLOCK_SIZE          forwarded to the kernel; also used to derive
//                       max_ctx_blocks from max_context_len.
//   HEAD_SIZE           must equal query.size(2).
//   PARTITION_SIZE_OLD  unused here; the partition size is fixed at 256.
//   ALIBI_ENABLED       forwarded to the kernel.
//
// Arguments (shapes as read by this function):
//   query               size(1) = num_heads, size(2) = head_size.
//   block_tables        [num_seqs, max_num_blocks_per_seq], int.
//   context_lens        int tensor; its data pointer is forwarded as-is.
//   query_start_loc     optional; forwarded as nullptr when absent.
//   alibi_slopes        must be empty (see above).
//   k_scale / v_scale   float tensors; their data pointers are forwarded.
//
// Fails via TORCH_CHECK when num_kv_heads is not positive, num_heads is not a
// multiple of num_kv_heads, head_size != HEAD_SIZE, alibi_slopes is provided,
// the GQA ratio is outside 1..16, or npar_loops is outside 1..16.
//
// NOTE: the local variable names below are part of the contract with the
// LAUNCH_CUSTOM_* macros, which capture them by name; do not rename them.
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
          bool ALIBI_ENABLED>
void paged_attention_custom_launcher_navi(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, const int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
    const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  int num_seqs = block_tables.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // Validate before doing any arithmetic with these values: the GQA ratio
  // below divides by num_kv_heads, so a zero must be rejected first.
  TORCH_CHECK(num_kv_heads > 0,
              "num_kv_heads must be positive, got: ", num_kv_heads);
  TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads (", num_heads,
              ") must be a multiple of num_kv_heads (", num_kv_heads, ")");
  TORCH_CHECK(head_size == HEAD_SIZE, "head_size (", head_size,
              ") does not match the compiled HEAD_SIZE (", HEAD_SIZE, ")");
  // The kernel is always handed a null slopes pointer on Navi. Reject slopes
  // explicitly instead of silently dropping them while an ALIBI_ENABLED
  // instantiation is launched.
  TORCH_CHECK(!alibi_slopes.has_value(),
              "alibi_slopes is not supported by the Navi custom paged "
              "attention kernel");

  // NOTE: query_start_loc is optional; for V0 decode it should not be used.
  // If the batch contains a mix of prefills and decodes, prefills should be
  // skipped.
  const int* query_start_loc_ptr =
      query_start_loc
          ? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
          : nullptr;

  // NOTE: Navi does not support alibi_slopes.
  const float* alibi_slopes_ptr = nullptr;

  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
  KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  // NOTE: Navi does not support fp8. The pointer is typed explicitly so the
  // reduce kernel receives a `const float*` rather than a std::nullptr_t.
  const float* fp8_out_scale_ptr = nullptr;

  OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

  const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);

  // Partition size is fixed at 256 for this launcher.
  constexpr int PARTITION_SIZE = 256;
  const int max_num_partitions =
      DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  const int gqa_ratio = num_heads / num_kv_heads;

  constexpr int NTHR = 256;
  dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
  dim3 block(NTHR);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // GQA_RATIO is a template argument of the kernel, so the runtime ratio has
  // to be mapped onto one explicit instantiation per supported value.
  switch (gqa_ratio) {
    case 1:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(1);
      break;
    case 2:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(2);
      break;
    case 3:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(3);
      break;
    case 4:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(4);
      break;
    case 5:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(5);
      break;
    case 6:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(6);
      break;
    case 7:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(7);
      break;
    case 8:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(8);
      break;
    case 9:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(9);
      break;
    case 10:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(10);
      break;
    case 11:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(11);
      break;
    case 12:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(12);
      break;
    case 13:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(13);
      break;
    case 14:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(14);
      break;
    case 15:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(15);
      break;
    case 16:
      LAUNCH_CUSTOM_ATTENTION_MFMA16(16);
      break;
    default:
      TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
      break;
  }

  dim3 reduce_grid(num_heads, num_seqs);
  dim3 reduce_block(head_size);
  constexpr int warp_size = 32;
  const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, warp_size);
  // The reduction kernel supports up to 16 NPAR_LOOPS * 32 (warp_size) * 256
  // (partition size) = 128K context length.
  switch (npar_loops) {
    case 1:
      LAUNCH_CUSTOM_REDUCTION(1);
      break;
    case 2:
      LAUNCH_CUSTOM_REDUCTION(2);
      break;
    case 3:
      LAUNCH_CUSTOM_REDUCTION(3);
      break;
    case 4:
      LAUNCH_CUSTOM_REDUCTION(4);
      break;
    case 5:
      LAUNCH_CUSTOM_REDUCTION(5);
      break;
    case 6:
      LAUNCH_CUSTOM_REDUCTION(6);
      break;
    case 7:
      LAUNCH_CUSTOM_REDUCTION(7);
      break;
    case 8:
      LAUNCH_CUSTOM_REDUCTION(8);
      break;
    case 9:
      LAUNCH_CUSTOM_REDUCTION(9);
      break;
    case 10:
      LAUNCH_CUSTOM_REDUCTION(10);
      break;
    case 11:
      LAUNCH_CUSTOM_REDUCTION(11);
      break;
    case 12:
      LAUNCH_CUSTOM_REDUCTION(12);
      break;
    case 13:
      LAUNCH_CUSTOM_REDUCTION(13);
      break;
    case 14:
      LAUNCH_CUSTOM_REDUCTION(14);
      break;
    case 15:
      LAUNCH_CUSTOM_REDUCTION(15);
      break;
    case 16:
      LAUNCH_CUSTOM_REDUCTION(16);
      break;
    default:
      TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
      break;
  }
}
// Picks the launcher that matches the GPU family and forwards the
// paged_attention() arguments to it.
//
// This macro is not hygienic: the expanding scope must provide the bool
// `is_navi` and every tensor / scalar named in the call below. The Navi
// launcher has no fp8_out_scale parameter, so that argument is passed only on
// the non-Navi branch.
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
                             PSIZE, ALIBI_ENABLED) \
  if (!is_navi) { \
    paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
                                    OUTT, PSIZE, ALIBI_ENABLED>( \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
        num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
        max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
  } else { \
    paged_attention_custom_launcher_navi< \
        T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
        num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
        max_context_len, alibi_slopes, k_scale, v_scale); \
  }
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE) \ OUTT, PSIZE) \
...@@ -1807,6 +3496,24 @@ void paged_attention_custom_launcher( ...@@ -1807,6 +3496,24 @@ void paged_attention_custom_launcher(
break; \ break; \
} }
bool is_navi_gpu() {
static bool is_cached = false;
static bool result;
if (!is_cached) {
int device_id;
hipDeviceProp_t deviceProp;
hipGetDevice(&device_id);
hipGetDeviceProperties(&deviceProp, device_id);
std::string arch = deviceProp.gcnArchName;
result = arch.find("gfx11") == 0 || arch.find("gfx12") == 0;
is_cached = true;
}
return result;
}
// clang-format off // clang-format off
void paged_attention( void paged_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
...@@ -1827,6 +3534,8 @@ void paged_attention( ...@@ -1827,6 +3534,8 @@ void paged_attention(
torch::Tensor& v_scale, torch::Tensor& v_scale,
const std::optional<torch::Tensor>& fp8_out_scale) { const std::optional<torch::Tensor>& fp8_out_scale) {
// clang-format on // clang-format on
bool is_navi = is_navi_gpu();
const int head_size = query.size(2); const int head_size = query.size(2);
if (kv_cache_dtype == "auto") { if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) { if (query.dtype() == at::ScalarType::Half) {
......
...@@ -148,6 +148,11 @@ def test_paged_attention( ...@@ -148,6 +148,11 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))): or (version == "rocm" and head_size not in (64, 128))):
pytest.skip() pytest.skip()
if (version == "rocm" and current_platform.is_navi()
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
pytest.skip()
global PARTITION_SIZE global PARTITION_SIZE
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
...@@ -275,6 +280,7 @@ def test_paged_attention( ...@@ -275,6 +280,7 @@ def test_paged_attention(
scale, scale,
block_tables, block_tables,
seq_lens, seq_lens,
None,
block_size, block_size,
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
...@@ -286,7 +292,7 @@ def test_paged_attention( ...@@ -286,7 +292,7 @@ def test_paged_attention(
opcheck(torch.ops._rocm_C.paged_attention, opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query, (output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, seq_lens, None, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale), kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
......
...@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window) decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
...@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode( ...@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size, block_size,
num_queries_per_kv, num_queries_per_kv,
max_seq_len, sliding_window) max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
if use_custom: if use_custom:
_PARTITION_SIZE_ROCM = 256 _PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
......
...@@ -102,27 +102,43 @@ def on_mi250_mi300() -> bool: ...@@ -102,27 +102,43 @@ def on_mi250_mi300() -> bool:
@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    The answer depends on the GPU architecture string reported for the
    "cuda" device and is memoised per argument tuple by ``@cache``.

    On gfx90a / gfx942 / gfx950 the kernel is used when all of these hold:
    sliding window is disabled (only required on V1), the query dtype is
    half or bfloat16, head_size is 64 or 128, block_size is 16 or 32,
    gqa_ratio is in [1, 16], max_seq_len <= 32768,
    VLLM_ROCM_CUSTOM_PAGED_ATTN is set, and AITER paged attention is not
    enabled.

    On gfx11 / gfx12 the requirements are stricter: head_size must be 128,
    block_size must be 16, gqa_ratio must be in [3, 16], no ALiBi slopes may
    be given, and kv_cache_dtype must be "auto".

    Any other architecture returns False.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    if ON_GFX9:
        return ((not envs.VLLM_USE_V1 or sliding_window == 0
                 or sliding_window == (-1, -1))
                and (qtype == torch.half or qtype == torch.bfloat16)
                and (head_size == 64 or head_size == 128)
                and (block_size == 16 or block_size == 32)
                and (gqa_ratio >= 1 and gqa_ratio <= 16)
                and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                         and envs.VLLM_ROCM_USE_AITER))

    else:
        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
                                    or sliding_window == (-1, -1))
                and (qtype == torch.half or qtype == torch.bfloat16)
                and head_size == 128 and block_size == 16
                and (gqa_ratio >= 3 and gqa_ratio <= 16)
                and max_seq_len <= 32768 and alibi_slopes is None
                and kv_cache_dtype == "auto"
                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
...@@ -362,3 +378,7 @@ class RocmPlatform(Platform): ...@@ -362,3 +378,7 @@ class RocmPlatform(Platform):
def get_cu_count(cls, device_id: int = 0) -> int: def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties( return torch.cuda.get_device_properties(
device_id).multi_processor_count device_id).multi_processor_count
@classmethod
def is_navi(cls) -> bool:
    """Return True when device 0 reports a gcnArchName containing 'gfx1'."""
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    return 'gfx1' in arch_name
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment