Commit e2bd7e16 authored by zhangshao's avatar zhangshao
Browse files

解决pa v2 bug

parent 304e2bab
......@@ -838,7 +838,6 @@ void paged_attention_v1_launcher_opt_tc(
get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
......@@ -855,6 +854,8 @@ void paged_attention_v1_launcher_opt_tc(
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
});
});
......@@ -897,7 +898,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \
}
void paged_attention_v1(
void paged_attention_v1_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
......@@ -935,7 +936,7 @@ void paged_attention_v1_opt_tc(
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
......@@ -961,7 +962,7 @@ void paged_attention_v1_opt_tc(
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
......@@ -1049,7 +1050,6 @@ void paged_attention_v2_launcher_opt_tc(
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
......@@ -1061,6 +1061,8 @@ void paged_attention_v2_launcher_opt_tc(
grid.z = num_seqs;
dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
});
});
......@@ -1105,7 +1107,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \
}
void paged_attention_v2(
void paged_attention_v2_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
......@@ -1151,7 +1153,7 @@ void paged_attention_v2_opt_tc(
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment