#pragma once

#include <cstdint>

// In the layout of fragment A and fragment C during WGMMA, the data each thread holds
// resides in two particular rows. This function converts local_row_idx (0 or 1) into the
// actual row index. For the detailed layout see:
// https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
//
//   idx_in_warpgroup / 32        -> warp index inside the warpgroup; each warp owns 16 rows
//   local_row_idx * 8            -> the second row of a thread's pair is 8 rows below the first
//   (idx_in_warpgroup % 32) / 4  -> groups of 4 consecutive lanes share one row
// Assumes idx_in_warpgroup is in [0, 128) (one warpgroup), giving a result in [0, 64).
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    const int warp_idx  = idx_in_warpgroup / 32;
    const int lane_quad = (idx_in_warpgroup % 32) / 4;
    return warp_idx * 16 + local_row_idx * 8 + lane_quad;
}

// Issue a warpgroup-wide WGMMA GEMM, tCrC (+)= tCrA * tCrB, iterating over the K mode
// (mode 2) of the A/B fragments.
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
//
// Template parameters:
//   zero_init - if true, the first k-block overwrites tCrC (scale-out = Zero). If false, the
//               first k-block uses whatever tiled_mma.accumulate_ the caller left in place.
//   wg_wait   - if >= 0, call warpgroup_wait<wg_wait>() (wgmma.wait_group) after the MMAs are
//               issued; pass a negative value to return without waiting.
//   arrive    - if true, call warpgroup_arrive() before issuing the MMAs.
//   commit    - if true, call warpgroup_commit_batch() after issuing the MMAs.
// After the call, tiled_mma.accumulate_ is GMMA::ScaleOut::One whenever at least one k-block ran.
template <bool zero_init = false, int wg_wait = 0, bool arrive = true, bool commit = true,
          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    // A is register-sourced (RS) unless its fragment type is a GMMA smem descriptor iterator (SS).
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    // Unroll the K mode manually so that scale-D is switched to One after the first k-block:
    // every subsequent k-block accumulates onto the partial result in tCrC.
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) {
        warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

// Launch the TMA copy src -> dst described by tma_copy, completing its transaction bytes on
// the barrier `bar`. Only slice 0 of the TMA is used; presumably a single elected thread calls
// this - TODO confirm against callers.
// NOTE(review): assumes transac_bar_t is layout-compatible with the raw 64-bit mbarrier word
// that the TMA `with(...)` overload expects.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    const Tensor0 &src,
    Tensor1 &dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL,
    const uint16_t &multicast_mask = 0
) {
    auto thr_tma = tma_copy.get_slice(_0{});
    cute::copy(
        tma_copy.with(reinterpret_cast<uint64_t &>(bar), multicast_mask, cache_hint),
        thr_tma.partition_S(src),
        thr_tma.partition_D(dst)
    );
}

// Asynchronously store the 16 bytes of `data` to `dst_ptr` through the .shared::cluster
// window, completing the stored byte count on the mbarrier at `mbar_ptr`
// (st.async ... mbarrier::complete_tx::bytes). Both pointers are converted with
// cute::cast_smem_ptr_to_uint, so both must point into shared memory. Requires SM90+.
template<typename T>
CUTE_DEVICE static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
    // The 128-bit vector read below is only valid when T is exactly 16 bytes.
    static_assert(sizeof(T) == 16, "st_async_128b stores exactly 128 bits; T must be 16 bytes");
    // longlong2 rather than long2: `long` is 32-bit on LLP64 hosts, which would not match the
    // .s64 operands of the instruction. On LP64 the two types are identical.
    const longlong2 payload = *reinterpret_cast<const longlong2*>(&data);
    const uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    const uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
    asm volatile (
        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
        :
        : "r"(dst_addr), "l"(payload.x), "l"(payload.y), "r"(mbar_addr)
    );
}

// peer_addr = my_addr ^ PEER_ADDR_MASK.
// It is not certain that this value is the same on every GPU. NOTE(review): the documented way
// to map a shared-memory address into another CTA of the cluster is the PTX `mapa` instruction.
static constexpr int PEER_ADDR_MASK = 1 << 24;  // == 16777216

// Return p with the PEER_ADDR_MASK bit toggled, i.e. (under the assumption above) the address
// of the same object in the peer CTA. Constness of the pointee is dropped.
template<typename T>
CUTE_DEVICE T* get_peer_addr(const T* p) {
    const uintptr_t peer_bits = reinterpret_cast<uintptr_t>(p) ^ PEER_ADDR_MASK;
    return const_cast<T*>(reinterpret_cast<const T*>(peer_bits));
}