// NOTE(review): token-separating whitespace appears to have been stripped from this fragment
// (e.g. `intmi`, `constintcol_idx_limit_right`, `DataType1scores_max_cur`) — confirm against
// the original file. The loop body continues past the lines documented here.
//
// Loop over mi while mi < WARP_M / 32; each iteration also advances rowcol.u32.x by
// BLOCK_ROW_STRIDE.
// Original author's caveat: when WARP_M > 32, block_row_idx is computed as BLOCK_M / 32
// rather than BLOCK_M / WARP_M.
for(intmi=0;mi<(WARP_M/32);++mi,rowcol.u32.x+=BLOCK_ROW_STRIDE){
// Right-hand column limit: the smaller of max_seqlen_k and
// row_idx + max_seqlen_k - max_seqlen_q (presumably the masking bound for this row — confirm).
// Original author's note: when max_seqlen_k == max_seqlen_q the expression reduces to
// min(max_seqlen_k, row_idx), so VGPR usage can be reduced further in that case.
constintcol_idx_limit_right=std::min(max_seqlen_k,row_idx+max_seqlen_k-max_seqlen_q);
// Earlier declaration kept for reference (note: "calculate max of each row"):
// float scores_max_cur[WARP_M/16];
// Buffer of WARP_M / 32 elements, passed as the last argument to int8_kvcache_reduce_max below.
DataType1scores_max_cur[(WARP_M/32)];
// Called with zero_init = false; per the original note, scores_max holds the previous scores max.
// The helper's exact semantics are defined outside this fragment.
int8_kvcache_reduce_max</*zero_init=*/false,DataType0,DataType1,WARP_M,WARP_N,M_MMAC_COUNT>(scores,scores_max,scores_max_cur);
// NOTE(review): token-separating whitespace appears to have been stripped from this line
// (`constexprint__kHeadDim`, `16orkHeadDim`) — confirm against the original file.
//
// Evaluates to kHeadDim unchanged when REUSE_KV_TIMES >= 16 or kHeadDim == 512,
// otherwise to kHeadDim + 4.
// Original author's rationale: for REUSE_KV_TIMES <= 15 the +4 misalignment can be used to
// reduce bank conflicts, whereas for larger values it may push LDS usage above 32KB and leave
// fewer waves per SIMD.
// NOTE(review): the original comment says ">16" while the condition tests ">= 16" — confirm
// which boundary is intended.
constexprint__kHeadDim=(REUSE_KV_TIMES>=16orkHeadDim==512)?kHeadDim:kHeadDim+4;