#pragma once #include #include #include #include "defines.h" #include "params.h" namespace sm90::fwd { using namespace cute; template class KernelTemplate { public: static constexpr int D_Q = D_QK; static constexpr int D_K = D_QK; static constexpr int D_V = 512; static constexpr int kNWarps = 4; static constexpr int B_H = 16; static constexpr int B_TOPK = 64; // TopK block size static constexpr int NUM_THREADS = kNWarps * 64; static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) using Element = cutlass::bfloat16_t; using elem_type = Element; using ElementAccum = float; using index_t = int64_t; static constexpr int kBlockM = B_H; static constexpr int kBlockN = B_TOPK; static constexpr int kHeadDim = D_QK; static constexpr int kHeadDimV = D_V; using ValLayoutMNK = Layout>; // 没打开? // #if defined(__gfx936__) || defined(__gfx938__) || 1 // using MMA_Atom_Arch = std::conditional_t< // std::is_same_v, // MMA_Atom, // MMA_Atom // >; using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // #endif using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using SmemLayoutAtomQ = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<16 * 32>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutRow = Layout, Stride<_1>>; using SmemLayoutK_place_holder = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<4 * 32>>{})); struct SharedMemoryPlan { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_place_holder; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_q; }; }; // transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; }; static __device__ __forceinline__ void devfunc(const SparseAttnFwdParams ¶ms); static void run(const SparseAttnFwdParams ¶ms); }; };