SM100_MMA_F16BF16_WS_TS_NOELECT<bf16,bf16,float,B_H,128,UMMA::Major::K,UMMA::Major::K>{}// Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: <TODO Fill link here>
// - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical within every warp, and is identical among threads that control the same row (i.e. it is identical among threads 0~31+64~95, and identical among threads 32~63+96~127)
// Calc scale factor, and scale li
// Pick the updated row max (new_max) and the matching factor scale_for_old.
// NOTE(review): scale_for_old is presumably applied to previously accumulated values; that use is outside this snippet - confirm.
floatnew_max,scale_for_old;
if(!should_scale_o){
// Don't scale O: keep mi unchanged and use the neutral factor 1
scale_for_old=1.0f;
new_max=mi;
}else{
// Fold cur_pi_max into the max. Barring NaN, new_max >= mi, so the exponent is <= 0 and scale_for_old <= 1
new_max=max(cur_pi_max,mi);
scale_for_old=exp2f(mi-new_max);
}
mi=new_max;// mi is still identical within each row
boolhave_valid_indices=__any_sync(0xffffffff,li!=0);// Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if(!have_valid_indices){
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
);// NOTE Using cp.async instead of TMA is faster here
// NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)
Tile<Int<128>,Layout<Shape<_128,_2,_2>,Stride<_1,_256,_128>>,_16>{}// We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
cute::tma_store_wait<0>();// This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document)
inttma_coords_step_per_block=params.stride_kv_block/TMA_K_STRIDE_FOR_DECODING;// must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512
tma_coords[i]=is_token_valid?block_idx*cur_tma_coords_step_per_block+idx_in_block*tma_coords_step_per_token:-1;// If the token is invalid because its top-k position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying in NaN.
int64_toffset=block_idx*cur_k_block_stride+(idx_in_block*8+(cta_idx==1?4:0));// Each token has 7 scale factors with an extra 1B padding
// NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter one guarantees the emptiness of p_exchange_buf and the former one guarantees the emptiness of rowwise_max_buf
// Shape preconditions checked before the rest of this code path runs.
KU_ASSERT(params.topk%B_TOPK==0);// topk must be a multiple of B_TOPK; this saves some boundary checks
KU_ASSERT(params.h_q==H_Q);// params.h_q must equal H_Q; this saves some calculation
KU_ASSERT(params.d_qk==D_QK);// params.d_qk must equal D_QK
static_assert(D_Q==512);// This code path is only valid when D_Q is exactly 512
CUtensorMaptensor_map_q;
ifconstexpr(IS_DECODE){
KU_ASSERT(params.stride_q_b%params.stride_q_s_q==0,"In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension).");
KU_ASSERT((int64_t)k_ptr%16==0,"The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f",is_extra?"extra_":"",k_ptr);
KU_ASSERT(stride_kv_block%TMA_K_STRIDE_FOR_DECODING==0,"%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary",is_extra?"extra_":"",stride_kv_block,TMA_K_STRIDE_FOR_DECODING);
);// NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens.
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows:
// - Wait for the 0-th tile to be ready using `barrier.wait()`
// - Compute Q K^T for the 0-th tile
// - Wait for the 1-st tile to be ready
// - Compute Q K^T for the 1-st tile
// ...
// This gives later tiles more time to become ready, and thus lets the memory copy overlap with the computation
// Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline.
// We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise
// nvcc cannot correctly analyse the loop, and will think that we are using accumulator
// registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized.
// This is also the reason why we put QK^T here, instead of the first operation in the loop
cute::warpgroup_wait<0>();
}
// A helper function for determining the length of the causal mask for one q token
// An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx).
// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
// Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence", and adjust end_block_idx based on it to save some calculation.
// Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks
// NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
staticconstexprintNUM_TOKENS_PER_ROUND=32;// If head is 128, each CTA is responsible for dequantizing 32 tokens (1 round); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds)
token_index=-1;// To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length
}
}
intblock_index=token_index==-1?0:(int)((uint32_t)token_index/(uint32_t)page_block_size);// Use uint32_t division and mod to improve performance
intrel_idx_in_block=(uint32_t)token_index%(uint32_t)page_block_size;// NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
fp8x16cur_fp8x16=load_128b_from_gmem<fp8x16,L1CacheHint::EVICT_LAST,L2PrefetchHint::B256>(gK_nope+dim_idx*64);// We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1)
// __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// // In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
dim3((params.h_q/B_H)*params.s_q,1,1),// NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)