Commit 76058dcd authored by zhangshao
Browse files

解决pa tc 版读取Q矩阵时数组越界的问题

parent 3f42b83d
...@@ -235,7 +235,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -235,7 +235,7 @@ __device__ void paged_attention_kernel_TC(
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16]; __shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
//if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx); //if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx);
for(int i=0;i<REUSE_KV_TIMES;i++){ for(int i=0;i<q_boundary;i++){
if(thread_idx<16){ if(thread_idx<16){
q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8); q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
} }
...@@ -1195,4 +1195,4 @@ void paged_attention_v2_opt_tc( ...@@ -1195,4 +1195,4 @@ void paged_attention_v2_opt_tc(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment