for(intmi=0;mi<(WARP_M/32);++mi,rowcol.u32.x+=BLOCK_ROW_STRIDE){// when WARP_M > 32, attention, block_row_idx is computed by BLOCK_M / 32 rather than BLOCK_M / WARP_M
constintcol_idx_limit_right=std::min(max_seqlen_k,row_idx+max_seqlen_k-max_seqlen_q);// attention, when max_seqlen_k == max_seqlen_q, vgpr can be reduced again
// float scores_max_cur[WARP_M/16]; //calculate max of each row
DataType1scores_max_cur[(WARP_M/32)];
int8_kvcache_reduce_max</*zero_init=*/false,DataType0,DataType1,WARP_M,WARP_N,M_MMAC_COUNT>(scores,scores_max,scores_max_cur);// scores_max is prev scores max
constexprint__kHeadDim=(REUSE_KV_TIMES>=16orkHeadDim==512)?kHeadDim:kHeadDim+4/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
// (lane_id & 1) * 16: in seqlen direction, [0,1,0,1,2,3,2,3], odd threads need skip 32 Halfs, 16 dword
// (laneid_and_15 >> 1) * 64: threads 0,1 occupy 4 lines, 4x32 Halfs, 64 dword.... 2,3 and 4,5 and 6,7 is the same
// laneid_and_15 >> 1, padding
// (laneid_shfl_4 & 1) * 8: threads 0,32 is even times of 16, thus 0,32; threads 16,48 is odd times of 16, thus 0,32,16,48; 0->16 need skip 16 Halfs, 8 dword
for(intmi=0;mi<(WARP_M/32);++mi,rowcol.u32.x+=BLOCK_ROW_STRIDE){// when WARP_M > 32, attention, block_row_idx is computed by BLOCK_M / 32 rather than BLOCK_M / WARP_M
kvcache_reduce_max</*zero_init=*/false,DataType0,DataType1,WARP_M,WARP_N,M_MMAC_COUNT>(scores,scores_max,scores_max_cur);// scores_max is prev scores max