Commit f38bd872 authored by flyingdown's avatar flyingdown
Browse files

pa add v prefetch for gemm1

parent 8ebc32aa
...@@ -20,7 +20,6 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -20,7 +20,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#include "static_switch.h" #include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
...@@ -68,10 +67,10 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -68,10 +67,10 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int REUSE_KV_TIMES = 1, int PARTITION_SIZE = 0,
bool odd_nheads = false, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> =
int PARTITION_SIZE = 0,std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning. 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
...@@ -99,10 +98,10 @@ __device__ void paged_attention_kernel( ...@@ -99,10 +98,10 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int REUSE_KV_TIMES = 1, int PARTITION_SIZE = 0,
bool odd_nheads = false, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> =
int PARTITION_SIZE = 0,std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning. 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
...@@ -138,7 +137,8 @@ __device__ void paged_attention_kernel( ...@@ -138,7 +137,8 @@ __device__ void paged_attention_kernel(
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size =
USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
...@@ -171,11 +171,11 @@ __device__ void paged_attention_kernel( ...@@ -171,11 +171,11 @@ __device__ void paged_attention_kernel(
// const int lane = thread_idx % WARP_SIZE; // const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE; // const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block int warp_id_vec = threadIdx.x / WARP_SIZE; // warp id in a block
int warp_idx =0; int warp_idx = 0;
asm volatile("v_readfirstlane_b32 %0,%1" asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx) : "=s"(warp_idx)
: "v"(warp_id_vec) : "v"(warp_id_vec)
...@@ -212,16 +212,18 @@ __device__ void paged_attention_kernel( ...@@ -212,16 +212,18 @@ __device__ void paged_attention_kernel(
// const scalar_t* q_ptr = q + seq_idx * q_stride; // const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride; const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec
// #pragma unroll q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; // #pragma unroll
// i += NUM_THREAD_GROUPS) { // for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; // i += NUM_THREAD_GROUPS) {
// q_vecs[thread_group_offset][i] = // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); // q_vecs[thread_group_offset][i] =
// } // *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a // }
// // memory wall right before we use q_vecs // __syncthreads(); // TODO(naed90): possible speedup if this is replaced
// with a
// // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -229,7 +231,8 @@ __device__ void paged_attention_kernel( ...@@ -229,7 +231,8 @@ __device__ void paged_attention_kernel(
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS]; __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]); // float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 *
// NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024]; // __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem); // float* logits = reinterpret_cast<float*>(shared_mem);
...@@ -240,41 +243,52 @@ __device__ void paged_attention_kernel( ...@@ -240,41 +243,52 @@ __device__ void paged_attention_kernel(
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES]; float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX; qk_max[reuse_kv_idx] = -FLT_MAX;
} }
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); const int num_blocks_per_kv =
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES; ((num_queries_per_kv + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES);
const int head_idx_soffset =
(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv +
(blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv; const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv; const int q_boundary = (kv_head_idx + 1) * num_queries_per_kv;
#pragma unroll #pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx; const int head_idx =
head_idx_soffset +
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE + thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx; const int head_idx =
if(!odd_nheads || head_idx < q_boundary) { head_idx_soffset +
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars // blocksparse specific vars
int bs_block_offset; int bs_block_offset;
int q_bs_block_id; int q_bs_block_id;
...@@ -284,8 +298,9 @@ __device__ void paged_attention_kernel( ...@@ -284,8 +298,9 @@ __device__ void paged_attention_kernel(
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = bs_block_offset = (tp_rank * num_heads + head_idx) *
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; blocksparse_head_sliding_step +
1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
...@@ -293,42 +308,48 @@ __device__ void paged_attention_kernel( ...@@ -293,42 +308,48 @@ __device__ void paged_attention_kernel(
1; 1;
} }
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; const int k_bs_block_id =
block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0);
const bool is_local = const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) { if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx =
block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to // NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will // avoid contribution to the sumexp softmax normalizer. This
// not be used at computing sum(softmax*v) as the blocks will be // will not be used at computing sum(softmax*v) as the blocks
// skipped. // will be skipped.
logits[token_idx - start_token_idx] = -FLT_MAX; logits[token_idx - start_token_idx] = -FLT_MAX;
} }
} }
continue; continue;
} }
} }
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in // For example, if the the thread group size is 4, then the first thread
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // in the group has 0, 4, 8, ... th vectors of the key, and the second
// has 1, 5, 9, ... th vectors of the key, and so on. // thread has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) { if (reuse_kv_idx == 0) {
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
...@@ -351,18 +372,25 @@ __device__ void paged_attention_kernel( ...@@ -351,18 +372,25 @@ __device__ void paged_attention_kernel(
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs); // group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE +
thread_group_offset],
k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk +=
(alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk; logits[(reuse_kv_idx * partition_size) +
(token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk); qk_max[reuse_kv_idx] =
mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
} }
} }
} }
...@@ -374,12 +402,14 @@ __device__ void paged_attention_kernel( ...@@ -374,12 +402,14 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx; const int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) { if (!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask)); qk_max[reuse_kv_idx] =
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx]; red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
...@@ -388,20 +418,25 @@ __device__ void paged_attention_kernel( ...@@ -388,20 +418,25 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX; qk_max[reuse_kv_idx] =
#pragma unroll lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask)); qk_max[reuse_kv_idx] =
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
// Broadcast the max qk value to all threads. // Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0); qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]); float val = __expf(logits[(reuse_kv_idx * partition_size) + i] -
qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val; logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val; exp_sum[reuse_kv_idx] += val;
} }
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]); exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(
&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
...@@ -416,7 +451,8 @@ __device__ void paged_attention_kernel( ...@@ -416,7 +451,8 @@ __device__ void paged_attention_kernel(
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx]; *max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx]; *exp_sums_ptr = exp_sum[reuse_kv_idx];
} }
...@@ -437,9 +473,9 @@ __device__ void paged_attention_kernel( ...@@ -437,9 +473,9 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f; accs[reuse_kv_idx][i] = 0.f;
} }
...@@ -453,15 +489,39 @@ __device__ void paged_attention_kernel( ...@@ -453,15 +489,39 @@ __device__ void paged_attention_kernel(
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
V_vec v_vec[2];
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
const int start_row_idx = lane / NUM_V_VECS_PER_ROW;
if (start_row_idx < HEAD_SIZE) {
const int offset = start_row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec[0] = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec[0] = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(
v_quant_vec, kv_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec[0]);
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int j = 0; j < V_VEC_SIZE; j++) {
V_vec v_vec; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { }
// NOTE(woosuk): The block number is stored in int32. However, we cast it to }
// int64 because int32 can lead to overflow when this variable is multiplied }
// by large numbers (e.g., kv_block_stride). #pragma unroll
// For blocksparse attention: skip computation on blocks that are not for (int i = 1; i < NUM_ROWS_PER_THREAD; i++) {
// attended for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES;
reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast
// it to int64 because int32 can lead to overflow when this variable is
// multiplied by large numbers (e.g., kv_block_stride). For blocksparse
// attention: skip computation on blocks that are not attended
// blocksparse specific vars // blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx; const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset; int bs_block_offset;
...@@ -472,8 +532,9 @@ __device__ void paged_attention_kernel( ...@@ -472,8 +532,9 @@ __device__ void paged_attention_kernel(
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = bs_block_offset = (tp_rank * num_heads + head_idx) *
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; blocksparse_head_sliding_step +
1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
...@@ -482,75 +543,119 @@ __device__ void paged_attention_kernel( ...@@ -482,75 +543,119 @@ __device__ void paged_attention_kernel(
} }
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue; continue;
} }
} }
if(!odd_nheads || head_idx < q_boundary) { if (!odd_nheads || head_idx < q_boundary) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(
logits + (reuse_kv_idx * partition_size) +
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride token_idx - start_token_idx));
+ kv_head_idx * kv_head_stride; // scalar_t* logits_vec_ptr =
// reinterpret_cast<scalar_t*>(&logits_vec); for(int i=0;i<8;++i){
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000); // from_float(*(logits_vec_ptr+i), 1000);
// } // }
if(reuse_kv_idx==0) { if (reuse_kv_idx == 0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = start_row_idx + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec[i % 2] = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec[i % 2] =
kv_scale); fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(
v_quant_vec, kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of
// context, we should explicitly zero out the values since they may // the context, we should explicitly zero out the values since
// contain NaNs. See // they may contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr =
reinterpret_cast<scalar_t*>(&v_vec[i % 2]);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] =
token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
// if(threadIdx.x==0){ // if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); // scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec); // scalar_t* logits_vec_ptr =
// for(int i=0;i<8;++i){ // reinterpret_cast<scalar_t*>(&logits_vec); for(int
// i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i])); // printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000); // // from_float(*(v_vec_ptr + i), 1000);
// } // }
// for(int i=0;i<8;++i){ // for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i])); // printf("logits_vec[%d] =
// %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000); // // from_float(*(logits_vec_ptr + i), 1000);
// } // }
// } // }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec); // accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
} }
} }
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec); accs[reuse_kv_idx][i-1] +=
dot(logits_vec, v_vec[(i - 1) % 2]);
}
}
}
// tail
{
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES;
reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) *
blocksparse_head_sliding_step +
1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if (!odd_nheads || head_idx < q_boundary) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(
logits + (reuse_kv_idx * partition_size) +
token_idx - start_token_idx));
accs[reuse_kv_idx][NUM_ROWS_PER_THREAD - 1] +=
dot(logits_vec, v_vec[(NUM_ROWS_PER_THREAD - 1) % 2]);
} }
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx; int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) { if (!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i]; float acc = accs[reuse_kv_idx][i];
...@@ -572,10 +677,12 @@ __device__ void paged_attention_kernel( ...@@ -572,10 +677,12 @@ __device__ void paged_attention_kernel(
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) +
(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx =
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i]; dst[row_idx] = accs[reuse_kv_idx][i];
} }
...@@ -585,10 +692,13 @@ __device__ void paged_attention_kernel( ...@@ -585,10 +692,13 @@ __device__ void paged_attention_kernel(
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE]; const float* src =
&out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) +
warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx =
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx]; accs[reuse_kv_idx][i] += src[row_idx];
} }
...@@ -601,7 +711,8 @@ __device__ void paged_attention_kernel( ...@@ -601,7 +711,8 @@ __device__ void paged_attention_kernel(
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -614,14 +725,11 @@ __device__ void paged_attention_kernel( ...@@ -614,14 +725,11 @@ __device__ void paged_attention_kernel(
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1, int REUSE_KV_TIMES = 1, bool IS_BLOCK_SPARSE, bool odd_nheads = false>
bool IS_BLOCK_SPARSE, __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel(
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
...@@ -652,11 +760,9 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( ...@@ -652,11 +760,9 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES, int PARTITION_SIZE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false> bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -679,18 +785,18 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( ...@@ -679,18 +785,18 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, PARTITION_SIZE>(
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads,
kv_block_stride, kv_head_stride, kv_scale, tp_rank, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale,
blocksparse_head_sliding_step); tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE>
int PARTITION_SIZE> __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel(
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -798,19 +904,20 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( ...@@ -798,19 +904,20 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ ((void*)vllm::paged_attention_v1_kernel< \
BLOCK_SIZE, NUM_THREADS, \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \ REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \ shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL( \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \ (vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
, dim3(grid), dim3(block), shared_mem_size, stream, \ NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ IS_BLOCK_SPARSE, odd_nheads>), \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ dim3(grid), dim3(block), shared_mem_size, stream, out_ptr, query_ptr, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, scale, \
kv_scale, tp_rank, blocksparse_local_blocks, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
blocksparse_vert_stride, blocksparse_block_size, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, kv_scale, \
blocksparse_head_sliding_step); tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ // #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ // vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...@@ -842,8 +949,8 @@ void paged_attention_v1_launcher( ...@@ -842,8 +949,8 @@ void paged_attention_v1_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128; int num_threads = 128;
if(num_heads!=num_kv_heads){ if (num_heads != num_kv_heads) {
num_threads =256; num_threads = 256;
} }
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
...@@ -861,25 +968,36 @@ void paged_attention_v1_launcher( ...@@ -861,25 +968,36 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len =
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] { DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { REUSEKV_SWITCH_V1(num_heads * num_seqs, [&] {
BOOL_SWITCH(
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] { HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] { NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float); int logits_size =
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float); REUSE_KV_TIMES * padded_max_seq_len * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) *
// Keep that in sync with the logic here! head_size * sizeof(float);
// Python-side check in
// vllm.worker.worker._check_if_can_support_max_seq_len Keep
// that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size); if (num_heads == num_kv_heads)
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size)); shared_mem_size = ::max(12 * 1024, shared_mem_size);
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl; // int shared_mem_size = ::max(31*1024, ::max(logits_size,
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs); // outputs_size)); std::cout<<"shared_mem_size =
// "<<shared_mem_size<<std::endl;
dim3 grid((num_heads / num_kv_heads + REUSE_KV_TIMES - 1) /
REUSE_KV_TIMES * num_kv_heads,
1, num_seqs);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); device_of(query));
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
}); });
}); });
...@@ -953,20 +1071,22 @@ void paged_attention_v1( ...@@ -953,20 +1071,22 @@ void paged_attention_v1(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL( \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ (vllm::paged_attention_v2_kernel< \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
, dim3(grid), dim3(block), shared_mem_size, stream, \ IS_BLOCK_SPARSE, REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>), \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \ kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_vert_stride, blocksparse_block_size, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ blocksparse_head_sliding_step); \
PARTITION_SIZE>) \ hipLaunchKernelGGL( \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \ (vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -1010,28 +1130,35 @@ void paged_attention_v2_launcher( ...@@ -1010,28 +1130,35 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] { REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs, [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { BOOL_SWITCH(
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] { HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float); int logits_size = REUSE_KV_TIMES * PARTITION_SIZE * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size =
REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel. // For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs); // dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid; dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads; grid.x = (num_heads / num_kv_heads + REUSE_KV_TIMES - 1) /
REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions; grid.y = max_num_partitions;
grid.z = num_seqs; grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size)); // int shared_mem_size = ::max(1024*32, ::max(logits_size,
// outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel. // For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); int reduce_shared_mem_size =
2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); device_of(query));
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment