Get P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary
Initially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 and the other occupying rows 64 ~ 127. We'd like to reduce them into a single P piece, stored in registers with layout:
// We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load
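// A minimal host-side sketch of why masking may safely precede the reduction.
// `masked_reduce` and its parameters are hypothetical names for illustration
// and do not appear in the kernel; which of the two P pieces gets masked here
// is likewise an assumption.

```cpp
#include <limits>

// Hypothetical helper: mask one partial value first, then reduce (add) it
// with the other partial value.
inline float masked_reduce(float p0, float p1, bool is_valid) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    if (!is_valid) {
        p0 = neg_inf;  // Masking happens before the reduction
    }
    // (-inf) + p1 is (-inf) for any p1 other than NaN or +inf,
    // so a masked position stays masked after the add.
    return p0 + p1;
}
```

// For a valid position the result is the plain sum; for a masked position it
// is -inf no matter what finite value the other piece holds.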
Tile<Int<128>, Layout<Shape<_128,_2,_2>, Stride<_1,_256,_128>>, _16>{} // We use this permutation layout so that CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Vote across the warp so every thread takes the same branch; if some threads had li == 0 and others li != 0, the divergence would deadlock during ku::tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
// NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.
// NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.
// NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, 128, UMMA::Major::K, UMMA::Major::K>{} // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: <TODO Fill link here>
// - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical within every warp, and is identical among threads that control the same row (i.e. among threads 0~31+64~95, and among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
    // Don't scale O
    scale_for_old = 1.0f;
    new_max = mi;
} else {
    new_max = max(cur_pi_max, mi);
    scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
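// A minimal host-side sketch of the rescaling step above, using one scalar per
// row. `RowState` and `rescale_step` are hypothetical names for illustration,
// and applying the factor to a single `o` value is an assumption about how the
// factor is consumed; the kernel itself works on per-thread registers and TMEM.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical per-row running state: running max (mi), running sum (li),
// and one output accumulator element (o), all in the base-2 exponent domain.
struct RowState {
    float mi;
    float li;
    float o;
};

// Folds a new block maximum into the running state and returns the factor
// that was applied to the old sum and old output.
inline float rescale_step(RowState& s, float cur_pi_max, bool should_scale_o) {
    float new_max, scale_for_old;
    if (!should_scale_o) {
        // Keep the old maximum, so nothing accumulated so far needs rescaling
        scale_for_old = 1.0f;
        new_max = s.mi;
    } else {
        new_max = std::max(cur_pi_max, s.mi);
        scale_for_old = std::exp2(s.mi - new_max);  // <= 1 since new_max >= mi
    }
    s.mi = new_max;
    s.li *= scale_for_old;  // Scale li
    s.o *= scale_for_old;   // Assumed: O is scaled by the same factor
    return scale_for_old;
}
```

// Example: with mi = 1, li = 2, o = 4 and cur_pi_max = 3, the new maximum is 3
// and the factor is 2^(1-3) = 0.25, so li becomes 0.5 and o becomes 1.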
); // NOTE: Using cp.async instead of TMA is faster here
// NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)
cute::tma_store_wait<0>(); // This thread must be the same one as the thread that copies O out (since `elect_one_sync()` always returns the same thread for the same `mask`, according to the PTX documentation)
int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // Must be < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512
tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because its topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying in NaN.
int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 4 : 0)); // Each token has 7 scale factors with an extra 1B padding
// NOTE: We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter guarantees the emptiness of p_exchange_buf and the former guarantees the emptiness of rowwise_max_buf
KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks
KU_ASSERT(params.h_q == H_Q); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
static_assert(D_Q == 512);
CUtensorMap tensor_map_q;
if constexpr (IS_DECODE) {
KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, "In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension).");
KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra ? "extra_" : "", k_ptr);
KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra ? "extra_" : "", stride_kv_block, TMA_K_STRIDE_FOR_DECODING);
); // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens.