staticconstexprintNUM_SCALES_EACH_TOKEN=MODEL_TYPE==ModelType::V32?4:8;// Padding is included
staticconstexprintTMA_K_STRIDE=MODEL_TYPE==ModelType::V32?D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE):D_NOPE+2*D_ROPE;// Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens 2) large enough to cover the entire KV cache. Since TMA copy's coordinate can only be 32bit signed integers, this number must >= 128, perferrably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
bf16q_sw64[B_H*D_Q_SW64];// NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
floatnew_max,scale_for_old;
if(!should_scale_o){
// Don't scale O
scale_for_old=1.0f;
new_max=mi;
}else{
new_max=max(cur_pi_max,mi);
scale_for_old=exp2f(mi-new_max);
}
mi=new_max;// mi is still identical within each row
floato_scale=li==0.0f?0.0f:__fdividef(1.0f,li);// Here we leave attn_sink to the combine kernel, otherwise attn_sink will take effect for multiple times
TensorsV=make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()),SmemLayoutKTilesTransposed_SW128<D_V/64>{});// NOTE: For MODEL1, it "expands" to the RoPE part.
tma_coords[i]=is_token_valid?block_idx*cur_tma_coords_step_per_block+idx_in_block*tma_coords_step_per_token:-1;// If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN.
KU_ASSERT(params.topk%B_TOPK==0,"topk (%d) mod B_TOPK (%d) must be 0",params.topk,B_TOPK);
KU_ASSERT(params.extra_topk%B_TOPK==0,"extra_topk (%d) mod B_TOPK (%d) must be 0",params.extra_topk,B_TOPK);
KU_ASSERT(params.h_q==B_H);
KU_ASSERT(params.h_kv==1);
KU_ASSERT(params.d_qk==D_Q);
KU_ASSERT(params.d_v==D_V);
ifconstexpr(MODEL_TYPE==ModelType::MODEL1){
constexprintBYTES_PER_TOKEN=D_NOPE+2*D_ROPE+8;
KU_ASSERT(params.stride_kv_row==BYTES_PER_TOKEN,"Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");// Each block must be contiguous
KU_ASSERT((int64_t)k_ptr%16==0,"The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f",is_extra?"extra_":"",k_ptr);
KU_ASSERT(k_batch_stride%TMA_K_STRIDE==0,"%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary",is_extra?"extra_":"",k_batch_stride,TMA_K_STRIDE);
usingTMEM_LOAD=std::conditional_t<kCorrectionTileSize==32,SM100_TMEM_LOAD_32dp32b32x,SM100_TMEM_LOAD_32dp32b16x>;// 4x32 threads with 64 cols of 32b elem
usingTMEM_LOAD=std::conditional_t<kCorrectionTileSize==32,SM100_TMEM_LOAD_32dp32b32x,SM100_TMEM_LOAD_32dp32b16x>;// 4x32 threads with 64 cols of 32b elem