Commit 1c2aa04c authored by zhuwenwen's avatar zhuwenwen
Browse files

update attention_kernels_opt_tc.cu

parent 83bb8f5a
......@@ -613,7 +613,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -793,7 +793,7 @@ void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher_opt(
void paged_attention_v1_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
......@@ -862,7 +862,7 @@ void paged_attention_v1_launcher_opt(
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v1_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
......@@ -915,7 +915,7 @@ void paged_attention_v1(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
......@@ -959,7 +959,7 @@ void paged_attention_v1_opt(
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
......@@ -1000,7 +1000,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher_opt(
void paged_attention_v2_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
......@@ -1069,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
......@@ -1127,7 +1127,7 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment