Commit e2bd7e16 authored by zhangshao
Browse files

解决pa v2 bug

parent 304e2bab
...@@ -838,7 +838,6 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -838,7 +838,6 @@ void paged_attention_v1_launcher_opt_tc(
get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads); get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
...@@ -855,6 +854,8 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -855,6 +854,8 @@ void paged_attention_v1_launcher_opt_tc(
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac); // printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs); dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
}); });
}); });
...@@ -897,7 +898,7 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -897,7 +898,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v1( void paged_attention_v1_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
...@@ -935,7 +936,7 @@ void paged_attention_v1_opt_tc( ...@@ -935,7 +936,7 @@ void paged_attention_v1_opt_tc(
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
...@@ -961,7 +962,7 @@ void paged_attention_v1_opt_tc( ...@@ -961,7 +962,7 @@ void paged_attention_v1_opt_tc(
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \ (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \ PARTITION_SIZE>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \ dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
...@@ -1049,7 +1050,6 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1049,7 +1050,6 @@ void paged_attention_v2_launcher_opt_tc(
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
...@@ -1061,6 +1061,8 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1061,6 +1061,8 @@ void paged_attention_v2_launcher_opt_tc(
grid.z = num_seqs; grid.z = num_seqs;
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
}); });
}); });
...@@ -1105,7 +1107,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1105,7 +1107,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v2( void paged_attention_v2_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
...@@ -1151,7 +1153,7 @@ void paged_attention_v2_opt_tc( ...@@ -1151,7 +1153,7 @@ void paged_attention_v2_opt_tc(
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment