Commit 76058dcd authored by zhangshao's avatar zhangshao
Browse files

解决pa tc 版读取Q矩阵时数组越界的问题

parent 3f42b83d
......@@ -235,7 +235,7 @@ __device__ void paged_attention_kernel_TC(
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
//if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx);
for(int i=0;i<REUSE_KV_TIMES;i++){
for(int i=0;i<q_boundary;i++){
if(thread_idx<16){
q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment