#pragma once

// NOTE(review): the two include targets were missing from the copy of this
// file under review. They are restored here from the names used below
// (cutlass::arch::ClusterTransactionBarrier, cute::*) - confirm against the
// build before merging.
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>

// Device-side helpers for Hopper (requires SM90: wgmma / TMA / createpolicy).
namespace sm90 {

// Issue one 16-byte asynchronous global->shared copy (`cp.async.cg`) with a
// 256-byte L2 prefetch hint.
//   src: global-memory source address.
//   dst: shared-memory destination (converted with cute::cast_smem_ptr_to_uint).
// The copy is only *issued* here; completion must be awaited by the caller.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
    asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
        :: "r"(dst_addr),
           "l"(src),
           "n"(16));
}

// Predicated variant of the copy above that also carries an L2 cache policy.
//   pred:         selects the src-size operand: 16 when true, 0 when false.
//                 Per the PTX `cp.async` definition, destination bytes beyond
//                 src-size are zero-filled, so `pred == false` zero-fills the
//                 16-byte destination instead of reading `src`.
//   cache_policy: 64-bit policy value, e.g. from createpolicy_evict_last() /
//                 createpolicy_evict_first().
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
    asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n"
        :: "r"(dst_addr),
           "l"(src),
           "r"(pred?16:0),
           "l"(cache_policy));
}

// Create an L2 cache policy with `createpolicy.fractional.L2::evict_last`
// and fraction 1.0. Returns the opaque 64-bit policy value.
__forceinline__ __device__ int64_t createpolicy_evict_last() {
    int64_t res;
    asm volatile(
        "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
        : "=l"(res)
        :
    );
    return res;
}

// Create an L2 cache policy with `createpolicy.fractional.L2::evict_first`
// and fraction 1.0. Returns the opaque 64-bit policy value.
__forceinline__ __device__ int64_t createpolicy_evict_first() {
    int64_t res;
    asm volatile(
        "createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t"
        : "=l"(res)
        :
    );
    return res;
}

// In the layout of fragment A and fragment C during WGMMA, the data each
// thread holds resides in two particular rows. This function converts
// local_row_idx (0 or 1: which of the thread's two rows) to the actual row
// index inside the warpgroup tile.
// You may refer to this link for the detailed layout:
// https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
//
//   row = 16 * warp-in-warpgroup + 8 * local_row_idx + (lane / 4)
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
    return row_idx;
}

// Column counterpart of get_AorC_row_idx: converts the index of an element
// within a thread's fragment to its column index.
//
//   col = 8 * (local_elem_idx / 4) + 2 * (idx_in_warpgroup % 4) + (local_elem_idx & 1)
//
// i.e. every group of four consecutive local elements covers one 8-column
// block, and the low bit picks one of the thread's two adjacent columns.
__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
    int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
    return col_idx;
}

// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
//
// Warpgroup GEMM: tCrC (+)= tCrA * tCrB, issued one K-block at a time.
//
// Template parameters
//   zero_init: when true, the first K-block overwrites the accumulator
//              (ScaleOut::Zero). When false, the first K-block uses whatever
//              `tiled_mma.accumulate_` holds on entry.
//   wg_wait:   when >= 0, calls warpgroup_wait<wg_wait>() after issuing;
//              a negative value skips the wait.
//   arrive:    call warpgroup_arrive() before issuing the MMAs.
//   commit:    call warpgroup_commit_batch() after issuing the MMAs.
// NOTE(review): this template parameter list (order and defaults) and the
// angle-bracket arguments of is_base_of / const_cast / warpgroup_wait were
// missing from the copy under review; they are restored to match the
// upstream flash-attention source referenced above - confirm.
//
// Side effect: `tiled_mma.accumulate_` is left at ScaleOut::One whenever at
// least one K-block was processed.
template <bool zero_init = false, int wg_wait = 0, bool arrive = true, bool commit = true,
          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    using namespace cute;
    // A is a register operand ("RS") unless its fragment type is a GMMA
    // shared-memory descriptor iterator.
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    // Unroll the K mode manually to set scale D to 1 after the first block.
    // (Previously this loop was duplicated verbatim in both zero_init branches.)
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

// A simpler version of gemm for the case where both A and B are passed as
// tensors to be partitioned by the MMA (the "SS" form).
//   clear_accum:      true  -> first K-block overwrites rC_frag,
//                     false -> first K-block accumulates into rC_frag.
//   tiled_mma:        taken by value, so the caller's accumulate_ is untouched.
//   sA, sB:           operand tensors; partitioned here via
//                     partition_fragment_A / partition_fragment_B.
//   rC_frag:          accumulator fragment, updated in place.
//   idx_in_warpgroup: thread index used to select the MMA slice.
// This function only arrives and issues: it does NOT commit the batch or wait
// for it - the caller must do so before reading rC_frag.
// NOTE(review): template parameter list restored (it was missing from the
// copy under review); all parameters are deduced from the arguments.
template <typename TiledMma, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor sA_frag = thr_mma.partition_fragment_A(sA);
    Tensor sB_frag = thr_mma.partition_fragment_B(sB);
    static_assert(size<2>(sA_frag) == size<2>(sB_frag));

    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();
    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size<2>(sA_frag); ++k) {
        cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
        // Every K-block after the first must accumulate.
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_fence_operand(rC_frag);
}

// Same as gemm_ss, but A is supplied as an already-partitioned fragment
// (rA_frag, passed by value) and only B is partitioned here (the "RS" form).
// rA_frag is fenced before and after issuing, in addition to rC_frag.
// Like gemm_ss, this does NOT commit or wait; the caller must.
// NOTE(review): template parameter list and the const_cast target type were
// missing from the copy under review and are restored here - confirm.
template <typename TiledMma, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor sB_frag = thr_mma.partition_fragment_B(sB);
    static_assert(size<2>(rA_frag) == size<2>(sB_frag));

    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();
    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size<2>(rA_frag); ++k) {
        cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
        // Every K-block after the first must accumulate.
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_fence_operand(rC_frag);
    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
}

// Read the PTX special register %smid (id of the SM executing this thread).
__forceinline__ __device__ uint32_t get_sm_id() {
    uint32_t ret;
    asm("mov.u32 %0, %%smid;" : "=r"(ret));
    return ret;
}

// peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs.
// (1 << 24 == 16777216 == 0x1000000.)
static constexpr int PEER_ADDR_MASK = 1 << 24;

// Return the "peer" address of `p`, obtained by flipping the PEER_ADDR_MASK
// bit of the pointer value. Constness of the pointee is dropped, as before.
// NOTE(review): presumably `p` is a shared-memory address and the peer is the
// other CTA of a 2-CTA cluster - confirm with callers.
template <typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    return reinterpret_cast<T*>(reinterpret_cast<int64_t>(p) ^ static_cast<int64_t>(PEER_ADDR_MASK));
}

// Launch a TMA copy from `src` to `dst` whose completion is tracked by `bar`.
//   tma_copy:   CuTe TMA copy object; slice 0 is used to partition src/dst.
//   src, dst:   source / destination tensors.
//   bar:        transaction barrier, handed to the copy as a reference to its
//               underlying ValueType.
//   cache_hint: L2 cache hint, EVICT_NORMAL by default.
// The literal 0 passed to `.with(...)` is presumably the multicast mask (no
// multicast) - confirm against the CuTe TMA copy traits in use.
// NOTE(review): the reinterpret_cast target type was missing from the copy
// under review; restored as the barrier's ValueType& - confirm.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    cutlass::arch::ClusterTransactionBarrier &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    auto thr_tma = tma_copy.get_slice(cute::_0{});
    cute::copy(
        tma_copy.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(bar), 0, cache_hint),
        thr_tma.partition_S(src),
        thr_tma.partition_D(dst)
    );
}

}  // namespace sm90