precompute_k_lds_offset[i*2+j]=reinterpret_cast<size_t>(k_lds_v2fp16)+(k_lds_stage_offset+head_dim_idx*WARP_N*17+n_idx*32*17+j*4+i*32+k_ds_read_offset)*4/*4 bytes per dword*/;
// (lane_id & 1) * 16: in the seqlen direction the order is [0,1,0,1,2,3,2,3]; odd threads need to skip 32 halfs, i.e. 16 dwords
// (laneid_and_15 >> 1) * 64: threads 0,1 occupy 4 lines, 4x32 halfs, i.e. 64 dwords; the same holds for threads 2,3 and 4,5 and 6,7
// laneid_and_15 >> 1: padding
// (laneid_shfl_4 & 1) * 8: threads 0,32 are even multiples of 16, thus 0,32; threads 16,48 are odd multiples of 16, thus 0,32,16,48; going from 0 to 16 needs to skip 16 halfs, i.e. 8 dwords
constexprintQ_LOAD_REQUESTS=(REUSE_KV_TIMES==0)?(kBlockM*kBlockK)/(4*32*WARP_NUM):MMAC_32x32?((REUSE_KV_TIMES+1)>>1)<<2/WARP_NUM:1/*MHA only need the first token*/;
constexprintQ_LOAD_REQUESTS=(REUSE_KV_TIMES==0)?(kBlockM*kBlockK)/(4*32*WARP_NUM):MMAC_32x32?((REUSE_KV_TIMES+1)>>1)<<2/WARP_NUM:1/*MHA only need the first token*/;
// Invoke mla_reduce_max with zero_init disabled, passing scores, scores_max
// and scores_max_cur. scores_max is the previous scores max.
mla_reduce_max</*zero_init=*/false,
               vec4_Accum<softmaxType>,
               vec2_Accum<softmaxType>,
               M_WARP_COUNT,
               N_WARP_COUNT,
               M_MMAC_COUNT>(scores, scores_max, scores_max_cur);
// PyTorch also has an implementation of the Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
constsize_tk_smem_size=use_tile_16x32?Kernel_traits::k_smem_size/34*32*WARP_NUM/*16x32 tile no padding*/:Kernel_traits::k_smem_size*WARP_NUM/*32x32 tile use padding 32 -> 34*/;
constsize_tsmem_misalign=(params.seqlen_q>=16orKernel_traits::kHeadDimV==512)?Kernel_traits::kHeadDimV:(Kernel_traits::kHeadDimV+4)/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
constsize_tsmem_misalign=(params.seqlen_q>=16orKernel_traits::kHeadDimV==512)?Kernel_traits::kHeadDimV:(Kernel_traits::kHeadDimV+4)/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;