// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
// We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race.
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
cute::cluster_sync();// We must use arrive_and_wait instead of arrive here to create an order between "forall warp in WG1, warp has done written back O" and "warp 2 signals `bar_k_avail`"
staticconstexprintNUM_TOKENS_PER_ROUND=32;// If head is 128, each CTA is responsible for dequantizing 32 tokens (1 rounds); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds)
intrel_idx_in_block=(token_index+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE;// NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error
token_index=-1;// To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length
}
}
intblock_index=token_index==-1?0:(int)((uint32_t)token_index/(uint32_t)page_block_size);// Use uint32_t division and mod to improve performance
intrel_idx_in_block=(uint32_t)token_index%(uint32_t)page_block_size;// NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
fp8x16cur_fp8x16=load_128b_from_gmem<fp8x16,L1CacheHint::EVICT_LAST,L2PrefetchHint::B256>(gK_nope+dim_idx*64);// We use EVICT_LAST here since gK_base may not be aligned to 32B
fp8x16cur_fp8x16=load_128b_from_gmem<fp8x16,L1CacheHint::EVICT_LAST,L2PrefetchHint::B256>(gK_nope+dim_idx*64);// We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1)
cute::cluster_sync();// Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync()
sync_all_threads_in_cluster();
}
}
#else
if(cute::thread0()){
...
...
@@ -541,50 +677,82 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
KU_ASSERT(params.stride_kv_row==BYTES_PER_TOKEN,"Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");// Each block must be contiguous
if(params.extra_kv!=nullptr){
KU_ASSERT(params.stride_extra_kv_row==BYTES_PER_TOKEN,"Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");// Each block must be contiguous
}
}else{
KU_ASSERT(params.extra_kv==nullptr,"V3.2 does not support extra KV cache");
KU_ASSERT(params.topk_length==nullptr,"V3.2 does not support dynamic topk length");
KU_ASSERT(params.stride_kv_row==656);// number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16)
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
array_aligned<bf16,cosize_v<SmemLayoutS>>s[D_QK==576?1:2];// For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers
boolis_kv_valid[2][B_TOPK];
float2sM[32];
float2sL[64];// For reduction across WG0/1 in epilogue
// NOTE This kernel uses a similar schedule to Flash MLA - 0422. For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md
dim3((params.h_q/B_H)*params.s_q,1,1),// NOTE We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)