Commit 749242a0 authored by flyingdown's avatar flyingdown
Browse files

Revert "pa add v prefetch for gemm1"

This reverts commit f38bd872.
parent dcaabcf7
...@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -20,6 +20,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#include "static_switch.h" #include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
...@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel( ...@@ -105,8 +106,7 @@ __device__ void paged_attention_kernel(
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
...@@ -139,11 +139,11 @@ __device__ void paged_attention_kernel( ...@@ -139,11 +139,11 @@ __device__ void paged_attention_kernel(
// const int lane = thread_idx % WARP_SIZE; // const int lane = thread_idx % WARP_SIZE;
// const int warp_idx = thread_idx / WARP_SIZE; //const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; // warp id in a block int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx = 0; int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1" asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx) : "=s"(warp_idx)
: "v"(warp_id_vec) : "v"(warp_id_vec)
...@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel( ...@@ -180,18 +180,16 @@ __device__ void paged_attention_kernel(
// const scalar_t* q_ptr = q + seq_idx * q_stride; // const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride; const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec __shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; // #pragma unroll
// #pragma unroll // for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; // i += NUM_THREAD_GROUPS) {
// i += NUM_THREAD_GROUPS) { // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; // q_vecs[thread_group_offset][i] =
// q_vecs[thread_group_offset][i] = // *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); // }
// } // __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced // // memory wall right before we use q_vecs
// with a
// // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel( ...@@ -199,8 +197,7 @@ __device__ void paged_attention_kernel(
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS]; __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * // float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024]; // __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem); // float* logits = reinterpret_cast<float*>(shared_mem);
...@@ -211,52 +208,41 @@ __device__ void paged_attention_kernel( ...@@ -211,52 +208,41 @@ __device__ void paged_attention_kernel(
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES]; float qk_max[REUSE_KV_TIMES];
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX; qk_max[reuse_kv_idx] = -FLT_MAX;
} }
const int num_blocks_per_kv = const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
((num_queries_per_kv + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES); const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int head_idx_soffset =
(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv +
(blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv; const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1) * num_queries_per_kv; const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset +
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE + thread_group_offset][i] = q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// memory wall right before we use q_vecs
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
head_idx_soffset + if(!odd_nheads || head_idx < q_boundary) {
reuse_kv_idx; // blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars // blocksparse specific vars
int bs_block_offset; int bs_block_offset;
int q_bs_block_id; int q_bs_block_id;
...@@ -266,9 +252,8 @@ __device__ void paged_attention_kernel( ...@@ -266,9 +252,8 @@ __device__ void paged_attention_kernel(
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) * bs_block_offset =
blocksparse_head_sliding_step + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
...@@ -276,48 +261,42 @@ __device__ void paged_attention_kernel( ...@@ -276,48 +261,42 @@ __device__ void paged_attention_kernel(
1; 1;
} }
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
0);
const bool is_local = const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) { if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to // NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This // avoid contribution to the sumexp softmax normalizer. This will
// will not be used at computing sum(softmax*v) as the blocks // not be used at computing sum(softmax*v) as the blocks will be
// will be skipped. // skipped.
logits[token_idx - start_token_idx] = -FLT_MAX; logits[token_idx - start_token_idx] = -FLT_MAX;
} }
} }
continue; continue;
} }
} }
const float alibi_slope = const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread // For example, if the the thread group size is 4, then the first thread in
// in the group has 0, 4, 8, ... th vectors of the key, and the second // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// thread has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
if (reuse_kv_idx == 0) { if(reuse_kv_idx == 0) {
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
...@@ -340,25 +319,18 @@ __device__ void paged_attention_kernel( ...@@ -340,25 +319,18 @@ __device__ void paged_attention_kernel(
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread // This includes a reduction across the threads in the same thread group.
// group. float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[reuse_kv_idx * THREAD_GROUP_SIZE +
thread_group_offset],
k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
(alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
(token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
} }
} }
} }
...@@ -370,14 +342,12 @@ __device__ void paged_attention_kernel( ...@@ -370,14 +342,12 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx; const int head_idx = head_idx_soffset + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) { if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx]; red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
...@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel( ...@@ -386,25 +356,20 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX; #pragma unroll
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
fmaxf(qk_max[reuse_kv_idx],
VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
// Broadcast the max qk value to all threads. // Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0); qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val; logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val; exp_sum[reuse_kv_idx] += val;
} }
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>( exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
...@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel( ...@@ -419,8 +384,7 @@ __device__ void paged_attention_kernel(
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx]; *max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx]; *exp_sums_ptr = exp_sum[reuse_kv_idx];
} }
...@@ -441,9 +405,9 @@ __device__ void paged_attention_kernel( ...@@ -441,9 +405,9 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f; accs[reuse_kv_idx][i] = 0.f;
} }
...@@ -457,39 +421,15 @@ __device__ void paged_attention_kernel( ...@@ -457,39 +421,15 @@ __device__ void paged_attention_kernel(
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
V_vec v_vec[2];
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
const int start_row_idx = lane / NUM_V_VECS_PER_ROW;
if (start_row_idx < HEAD_SIZE) {
const int offset = start_row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec[0] = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec[0] = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(
v_quant_vec, kv_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec[0]);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
}
#pragma unroll #pragma unroll
for (int i = 1; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; V_vec v_vec;
reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// it to int64 because int32 can lead to overflow when this variable is // int64 because int32 can lead to overflow when this variable is multiplied
// multiplied by large numbers (e.g., kv_block_stride). For blocksparse // by large numbers (e.g., kv_block_stride).
// attention: skip computation on blocks that are not attended // For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars // blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx; const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset; int bs_block_offset;
...@@ -500,9 +440,8 @@ __device__ void paged_attention_kernel( ...@@ -500,9 +440,8 @@ __device__ void paged_attention_kernel(
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) * bs_block_offset =
blocksparse_head_sliding_step + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
...@@ -511,119 +450,75 @@ __device__ void paged_attention_kernel( ...@@ -511,119 +450,75 @@ __device__ void paged_attention_kernel(
} }
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue; continue;
} }
} }
if (!odd_nheads || head_idx < q_boundary) { if(!odd_nheads || head_idx < q_boundary) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(
logits + (reuse_kv_idx * partition_size) +
token_idx - start_token_idx)); const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
// scalar_t* logits_vec_ptr = + kv_head_idx * kv_head_stride;
// reinterpret_cast<scalar_t*>(&logits_vec); for(int i=0;i<8;++i){
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000); // from_float(*(logits_vec_ptr+i), 1000);
// } // }
if (reuse_kv_idx == 0) { if(reuse_kv_idx==0) {
const int row_idx = start_row_idx + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec[i % 2] = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec[i % 2] = v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>( kv_scale);
v_quant_vec, kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of // NOTE(woosuk): When v_vec contains the tokens that are out of the
// the context, we should explicitly zero out the values since // context, we should explicitly zero out the values since they may
// they may contain NaNs. See // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
reinterpret_cast<scalar_t*>(&v_vec[i % 2]);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
// if(threadIdx.x==0){ // if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); // scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = // scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// reinterpret_cast<scalar_t*>(&logits_vec); for(int // for(int i=0;i<8;++i){
// i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i])); // printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000); // // from_float(*(v_vec_ptr + i), 1000);
// } // }
// for(int i=0;i<8;++i){ // for(int i=0;i<8;++i){
// printf("logits_vec[%d] = // printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000); // // from_float(*(logits_vec_ptr + i), 1000);
// } // }
// } // }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec); // accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
} }
} }
accs[reuse_kv_idx][i-1] += accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
dot(logits_vec, v_vec[(i - 1) % 2]);
}
}
}
// tail
{
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES;
reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset = (tp_rank * num_heads + head_idx) *
blocksparse_head_sliding_step +
1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride ==
0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if (!odd_nheads || head_idx < q_boundary) {
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(
logits + (reuse_kv_idx * partition_size) +
token_idx - start_token_idx));
accs[reuse_kv_idx][NUM_ROWS_PER_THREAD - 1] +=
dot(logits_vec, v_vec[(NUM_ROWS_PER_THREAD - 1) % 2]);
} }
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for (int reuse_kv_idx = 0; reuse_kv_idx < REUSE_KV_TIMES; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx; int head_idx = head_idx_soffset + reuse_kv_idx;
if (!odd_nheads || head_idx < q_boundary) { if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i]; float acc = accs[reuse_kv_idx][i];
...@@ -645,12 +540,10 @@ __device__ void paged_attention_kernel( ...@@ -645,12 +540,10 @@ __device__ void paged_attention_kernel(
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i]; dst[row_idx] = accs[reuse_kv_idx][i];
} }
...@@ -660,13 +553,10 @@ __device__ void paged_attention_kernel( ...@@ -660,13 +553,10 @@ __device__ void paged_attention_kernel(
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
&out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) +
warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx]; accs[reuse_kv_idx][i] += src[row_idx];
} }
...@@ -679,8 +569,7 @@ __device__ void paged_attention_kernel( ...@@ -679,8 +569,7 @@ __device__ void paged_attention_kernel(
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -694,11 +583,14 @@ __device__ void paged_attention_kernel( ...@@ -694,11 +583,14 @@ __device__ void paged_attention_kernel(
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1, bool IS_BLOCK_SPARSE, bool odd_nheads = false> int REUSE_KV_TIMES = 1,
__global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel( bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
...@@ -729,9 +621,11 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel( ...@@ -729,9 +621,11 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v1_kernel(
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES, int PARTITION_SIZE, bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false> bool odd_nheads = false>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -754,18 +648,18 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel( ...@@ -754,18 +648,18 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel(
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
PARTITION_SIZE>( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, kv_block_stride, kv_head_stride, kv_scale, tp_rank,
alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel( int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -873,20 +767,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel( ...@@ -873,20 +767,19 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel< \ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ BLOCK_SIZE, NUM_THREADS, \
REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \ KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \ shared_mem_size); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, \ , dim3(grid), dim3(block), shared_mem_size, stream, \
IS_BLOCK_SPARSE, odd_nheads>), \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
dim3(grid), dim3(block), shared_mem_size, stream, out_ptr, query_ptr, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, scale, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ kv_scale, tp_rank, blocksparse_local_blocks, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, kv_scale, \ blocksparse_vert_stride, blocksparse_block_size, \
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ // #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ // vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...@@ -918,8 +811,8 @@ void paged_attention_v1_launcher( ...@@ -918,8 +811,8 @@ void paged_attention_v1_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128; int num_threads = 128;
if (num_heads != num_kv_heads) { if(num_heads!=num_kv_heads){
num_threads = 256; num_threads =256;
} }
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
...@@ -937,36 +830,25 @@ void paged_attention_v1_launcher( ...@@ -937,36 +830,25 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
REUSEKV_SWITCH_V1(num_heads * num_seqs, [&] { BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
BOOL_SWITCH(
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] { HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] { NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
REUSE_KV_TIMES * padded_max_seq_len * sizeof(float); int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
head_size * sizeof(float); // Keep that in sync with the logic here!
// Python-side check in
// vllm.worker.worker._check_if_can_support_max_seq_len Keep
// that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
if (num_heads == num_kv_heads) if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
shared_mem_size = ::max(12 * 1024, shared_mem_size); // int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// int shared_mem_size = ::max(31*1024, ::max(logits_size, // std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// outputs_size)); std::cout<<"shared_mem_size = dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
// "<<shared_mem_size<<std::endl;
dim3 grid((num_heads / num_kv_heads + REUSE_KV_TIMES - 1) /
REUSE_KV_TIMES * num_kv_heads,
1, num_seqs);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard( const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
device_of(query)); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
}); });
}); });
...@@ -1040,22 +922,20 @@ void paged_attention_v1( ...@@ -1040,22 +922,20 @@ void paged_attention_v1(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL( \ hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
(vllm::paged_attention_v2_kernel< \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
IS_BLOCK_SPARSE, REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>), \ , dim3(grid), dim3(block), shared_mem_size, stream, \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_head_sliding_step); \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL( \ PARTITION_SIZE>) \
(vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ , dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
PARTITION_SIZE>), \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -1099,35 +979,28 @@ void paged_attention_v2_launcher( ...@@ -1099,35 +979,28 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs, [&] { REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH( BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
(num_heads / num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] { HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES * PARTITION_SIZE * sizeof(float); int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int outputs_size = int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel. // For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs); // dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid; dim3 grid;
grid.x = (num_heads / num_kv_heads + REUSE_KV_TIMES - 1) / grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions; grid.y = max_num_partitions;
grid.z = num_seqs; grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, // int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
// outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel. // For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard( const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
device_of(query)); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
const hipStream_t stream =
at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment