Commit 749242a0 authored by flyingdown's avatar flyingdown
Browse files

Revert "pa add v prefetch for gemm1"

This reverts commit f38bd872.
parent dcaabcf7
...@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#include "static_switch.h" #include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
...@@ -80,7 +81,7 @@ __device__ void paged_attention_kernel( ...@@ -80,7 +81,7 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads] const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads] const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel( ...@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel(
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
...@@ -139,15 +139,15 @@ __device__ void paged_attention_kernel( ...@@ -139,15 +139,15 @@ __device__ void paged_attention_kernel(
// const int lane = thread_idx % WARP_SIZE; // const int lane = thread_idx % WARP_SIZE;
// const int warp_idx = thread_idx / WARP_SIZE; //const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; // warp id in a block int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx = 0; int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1" asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx) : "=s"(warp_idx)
: "v"(warp_id_vec) : "v"(warp_id_vec)
:); :);
// const int head_idx = blockIdx.x; // const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x; // const int num_heads = gridDim.x;
...@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel( ...@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel(
// const scalar_t* q_ptr = q + seq_idx * q_stride; // const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride; const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec __shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; // #pragma unroll
// #pragma unroll // for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; // i += NUM_THREAD_GROUPS) {
// i += NUM_THREAD_GROUPS) { // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; // q_vecs[thread_group_offset][i] =
// q_vecs[thread_group_offset][i] = // *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); // }
// } // __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced // // memory wall right before we use q_vecs
// with a
// // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel( ...@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel(
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS]; __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * // float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024]; // __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem); // float* logits = reinterpret_cast<float*>(shared_mem);
...@@ -211,173 +208,146 @@ __device__ void paged_attention_kernel( ...@@ -211,173 +208,146 @@ __device__ void paged_attention_kernel(
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES]; float qk_max[REUSE_KV_TIMES];
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX; qk_max[reuse_kv_idx] = -FLT_MAX;
} }
const int num_blocks_per_kv = const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
((num_queries_per_kv + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES); const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int head_idx_soffset =
(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv +
(blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv; const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1) * num_queries_per_kv; const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset +
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE + thread_group_offset][i] = q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// memory wall right before we use q_vecs
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset + if(!odd_nheads || head_idx < q_boundary) {
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars // blocksparse specific vars
int bs_block_offset; int bs_block_offset;
int q_bs_block_id; int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size); // blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) * bs_block_offset =
blocksparse_head_sliding_step + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
1; else
else // sliding on kv heads
// sliding on kv heads bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * (-blocksparse_head_sliding_step) +
(-blocksparse_head_sliding_step) + 1;
1; }
} if constexpr (IS_BLOCK_SPARSE) {
if constexpr (IS_BLOCK_SPARSE) { const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const int k_bs_block_id = const bool is_remote =
block_idx * BLOCK_SIZE / blocksparse_block_size; ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_remote = const bool is_local =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
0); if (!is_remote && !is_local) {
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx =
block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This
// will not be used at computing sum(softmax*v) as the blocks
// will be skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread
// in the group has 0, 4, 8, ... th vectors of the key, and the second
// thread has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
if (reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread
// group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE +
thread_group_offset],
k_vecs);
// Add the ALiBi bias if slopes are given.
qk +=
(alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // NOTE(linxihui): assign very large number to skipped tokens to
// NOTE(woosuk): It is required to zero out the masked logits. // avoid contribution to the sumexp softmax normalizer. This will
const bool mask = token_idx >= seq_len; // not be used at computing sum(softmax*v) as the blocks will be
logits[(reuse_kv_idx * partition_size) + // skipped.
(token_idx - start_token_idx)] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = -FLT_MAX;
// Update the max value. }
qk_max[reuse_kv_idx] = }
mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk); continue;
}
}
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
} }
} }
} }
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
}
} }
} }
}
}
// Get the sum of the exp values. // Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f}; float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx; const int head_idx = head_idx_soffset + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) { if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx]; red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
...@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel( ...@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX; #pragma unroll
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
// Broadcast the max qk value to all threads. // Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0); qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val; logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val; exp_sum[reuse_kv_idx] += val;
} }
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>( exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
...@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel( ...@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel(
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx]; *max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx]; *exp_sums_ptr = exp_sum[reuse_kv_idx];
} }
...@@ -441,11 +405,11 @@ __device__ void paged_attention_kernel( ...@@ -441,11 +405,11 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f; accs[reuse_kv_idx][i] = 0.f;
} }
} }
scalar_t zero_value; scalar_t zero_value;
...@@ -457,230 +421,155 @@ __device__ void paged_attention_kernel( ...@@ -457,230 +421,155 @@ __device__ void paged_attention_kernel(
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
V_vec v_vec[2];
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
const int start_row_idx = lane / NUM_V_VECS_PER_ROW;
if (start_row_idx < HEAD_SIZE) {
const int offset = start_row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec[0] = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec[0] = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(
v_quant_vec, kv_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec[0]);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; V_vec v_vec;
} for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
} }
} if constexpr (IS_BLOCK_SPARSE) {
#pragma unroll int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
for (int i = 1; i < NUM_ROWS_PER_THREAD; i++) { if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
reuse_kv_idx++) { continue;
// NOTE(woosuk): The block number is stored in int32. However, we cast
// it to int64 because int32 can lead to overflow when this variable is
// multiplied by large numbers (e.g., kv_block_stride). For blocksparse
// attention: skip computation on blocks that are not attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) *
blocksparse_head_sliding_step +
1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
} }
if constexpr (IS_BLOCK_SPARSE) { }
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; if(!odd_nheads || head_idx < q_boundary) {
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
continue; + kv_head_idx * kv_head_stride;
}
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
} }
if (!odd_nheads || head_idx < q_boundary) { if (block_idx == num_seq_blocks - 1) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>( // NOTE(woosuk): When v_vec contains the tokens that are out of the
logits + (reuse_kv_idx * partition_size) + // context, we should explicitly zero out the values since they may
token_idx - start_token_idx)); // contain NaNs. See
// scalar_t* logits_vec_ptr = // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
// reinterpret_cast<scalar_t*>(&logits_vec); for(int i=0;i<8;++i){ scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// from_float(*(logits_vec_ptr+i), 1000);
// }
if (reuse_kv_idx == 0) {
const int row_idx = start_row_idx + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec[i % 2] = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec[i % 2] =
fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(
v_quant_vec, kv_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of
// the context, we should explicitly zero out the values since
// they may contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr =
reinterpret_cast<scalar_t*>(&v_vec[i % 2]);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr =
// reinterpret_cast<scalar_t*>(&logits_vec); for(int
// i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] =
// %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
} }
accs[reuse_kv_idx][i-1] +=
dot(logits_vec, v_vec[(i - 1) % 2]);
} }
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
} }
}
// tail
{
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES;
reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) *
blocksparse_head_sliding_step +
1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if (!odd_nheads || head_idx < q_boundary) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(
logits + (reuse_kv_idx * partition_size) +
token_idx - start_token_idx));
accs[reuse_kv_idx][NUM_ROWS_PER_THREAD - 1] +=
dot(logits_vec, v_vec[(NUM_ROWS_PER_THREAD - 1) % 2]);
}
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx; int head_idx = head_idx_soffset + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) { if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i]; float acc = accs[reuse_kv_idx][i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[reuse_kv_idx][i] = acc; accs[reuse_kv_idx][i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for // NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem); float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll #pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) { for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { dst[row_idx] = accs[reuse_kv_idx][i];
dst[row_idx] = accs[reuse_kv_idx][i];
}
}
} }
__syncthreads(); }
}
__syncthreads();
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
&out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) +
warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { accs[reuse_kv_idx][i] += src[row_idx];
accs[reuse_kv_idx][i] += src[row_idx];
}
}
} }
__syncthreads();
} }
}
__syncthreads();
}
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -694,18 +583,21 @@ __device__ void paged_attention_kernel( ...@@ -694,18 +583,21 @@ __device__ void paged_attention_kernel(
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1, bool IS_BLOCK_SPARSE, bool odd_nheads = false> int REUSE_KV_TIMES = 1,
__global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel( bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads] const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -716,22 +608,24 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel( ...@@ -716,22 +608,24 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES, int PARTITION_SIZE, bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false> bool odd_nheads = false>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -742,7 +636,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel( ...@@ -742,7 +636,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads] const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads] const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -753,19 +647,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel( ...@@ -753,19 +647,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
PARTITION_SIZE>( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, kv_block_stride, kv_head_stride, kv_scale, tp_rank,
alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel( int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -871,22 +765,21 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel( ...@@ -871,22 +765,21 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel(
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel< \ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ BLOCK_SIZE, NUM_THREADS, \
REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \ KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \ shared_mem_size); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, \ , dim3(grid), dim3(block), shared_mem_size, stream, \
IS_BLOCK_SPARSE, odd_nheads>), \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
dim3(grid), dim3(block), shared_mem_size, stream, out_ptr, query_ptr, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, scale, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ kv_scale, tp_rank, blocksparse_local_blocks, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, kv_scale, \ blocksparse_vert_stride, blocksparse_block_size, \
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ // #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ // vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...@@ -918,8 +811,8 @@ void paged_attention_v1_launcher( ...@@ -918,8 +811,8 @@ void paged_attention_v1_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128; int num_threads = 128;
if (num_heads != num_kv_heads) { if(num_heads!=num_kv_heads){
num_threads = 256; num_threads =256;
} }
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
...@@ -937,42 +830,31 @@ void paged_attention_v1_launcher( ...@@ -937,42 +830,31 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
REUSEKV_SWITCH_V1(num_heads * num_seqs, [&] { BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
BOOL_SWITCH( HEADSIZE_SWITCH(head_size, [&] {
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { NUM_THREADS_SWITCH(num_threads, [&] {
HEADSIZE_SWITCH(head_size, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
NUM_THREADS_SWITCH(num_threads, [&] { constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
OPT_SWITCH(num_heads == num_kv_heads, [&] { int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
int logits_size = // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
REUSE_KV_TIMES * padded_max_seq_len * sizeof(float); // Keep that in sync with the logic here!
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * int shared_mem_size = ::max(logits_size, outputs_size);
head_size * sizeof(float); if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// Python-side check in // int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// vllm.worker.worker._check_if_can_support_max_seq_len Keep // std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// that in sync with the logic here! dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
int shared_mem_size = ::max(logits_size, outputs_size); dim3 block(NUM_THREADS);
if (num_heads == num_kv_heads) const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
shared_mem_size = ::max(12 * 1024, shared_mem_size); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
// int shared_mem_size = ::max(31*1024, ::max(logits_size, LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
// outputs_size)); std::cout<<"shared_mem_size =
// "<<shared_mem_size<<std::endl;
dim3 grid((num_heads / num_kv_heads + REUSE_KV_TIMES - 1) /
REUSE_KV_TIMES * num_kv_heads,
1, num_seqs);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(
device_of(query));
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
});
});
}); });
}); });
}); });
});
});
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...@@ -1040,23 +922,21 @@ void paged_attention_v1( ...@@ -1040,23 +922,21 @@ void paged_attention_v1(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL( \ hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v2_kernel< \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
IS_BLOCK_SPARSE, REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>), \ , dim3(grid), dim3(block), shared_mem_size, stream, \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_head_sliding_step); \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL( \ PARTITION_SIZE>) \
(vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ , dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
PARTITION_SIZE>), \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \ max_num_partitions);
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
...@@ -1099,39 +979,32 @@ void paged_attention_v2_launcher( ...@@ -1099,39 +979,32 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs, [&] { REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH( BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { HEADSIZE_SWITCH(head_size, [&] {
HEADSIZE_SWITCH(head_size, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int logits_size = REUSE_KV_TIMES * PARTITION_SIZE * sizeof(float); int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
int outputs_size =
REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs); dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
dim3 grid; grid.y = max_num_partitions;
grid.x = (num_heads / num_kv_heads + REUSE_KV_TIMES - 1) / grid.z = num_seqs;
REUSE_KV_TIMES * num_kv_heads; // int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
grid.y = max_num_partitions; int shared_mem_size = ::max(logits_size, outputs_size);
grid.z = num_seqs; // For paged attention v2 reduce kernel.
// int shared_mem_size = ::max(1024*32, ::max(logits_size, dim3 reduce_grid(num_heads, num_seqs);
// outputs_size)); int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
int shared_mem_size = ::max(logits_size, outputs_size); dim3 block(NUM_THREADS);
// For paged attention v2 reduce kernel. const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
dim3 reduce_grid(num_heads, num_seqs); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
int reduce_shared_mem_size = LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(
device_of(query));
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
});
});
}); });
});
});
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment