Commit 1c2aa04c authored by zhuwenwen
Browse files

update attention_kernels_opt_tc.cu

parent 83bb8f5a
...@@ -613,7 +613,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -613,7 +613,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt( __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -793,7 +793,7 @@ void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize ...@@ -793,7 +793,7 @@ void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher_opt( void paged_attention_v1_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
...@@ -862,7 +862,7 @@ void paged_attention_v1_launcher_opt( ...@@ -862,7 +862,7 @@ void paged_attention_v1_launcher_opt(
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v1_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
...@@ -915,7 +915,7 @@ void paged_attention_v1( ...@@ -915,7 +915,7 @@ void paged_attention_v1(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt( void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
...@@ -959,7 +959,7 @@ void paged_attention_v1_opt( ...@@ -959,7 +959,7 @@ void paged_attention_v1_opt(
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \ blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \ (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \ PARTITION_SIZE>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \ dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
...@@ -1000,7 +1000,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize ...@@ -1000,7 +1000,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher_opt( void paged_attention_v2_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
...@@ -1069,7 +1069,7 @@ void paged_attention_v2_launcher_opt( ...@@ -1069,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
...@@ -1127,7 +1127,7 @@ void paged_attention_v2( ...@@ -1127,7 +1127,7 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt( void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment