usingFrgTypeA=UMMA::tmem_frg_1sm<a_type,a_type,UMMA::TmemAllocMode::NonInterleaved>;// Actually this should be "duplicated", however, our great CuTe doesn't allow us to set it to "duplicated", so we just set it to NonInterleaved for a correct address calculation
usingFrgTypeB=UMMA::smem_desc<b_major>;
usingFrgTypeC=UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256 bits; transform to units of elements
// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
// Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)
// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions
// Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.
// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)
// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): PTX document says that %smid ranges from 0 to %nsmid-1, while "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, result shows that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. For the sake of safety, I recommend you to check the return of get_sm_id manually or call `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`.
// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT]
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
returnrow_idx;
// return row_idx;
}
// }
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx
// // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// static constexpr int PEER_ADDR_MASK = 16777216;
// Given an address in the current CTA, return the corresponding address in the peer CTA
// // Given an address in the current CTA, return the corresponding address in the peer CTA
template<typenameT>
// template<typename T>
CUTE_DEVICE
// CUTE_DEVICE
T*get_peer_addr(constT*p){
// T* get_peer_addr(const T* p) {
return(T*)((int64_t)(p)^PEER_ADDR_MASK);
// return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
// }
// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0)
// // Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0)
// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
// // cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernel