#pragma once #include #include #include #include #include "defines.h" #include "params.h" using namespace cute; namespace sm90::decode::sparse_fp8 { template class KernelTemplate { public: static_assert(NUM_HEADS == 64 || NUM_HEADS == 128); static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64; static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS; static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512; static constexpr int HEAD_DIM_V = 512; static constexpr int HEAD_DIM_ROPE = 64; static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE; static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64; static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8; // For MODEL1: 7 fp8_e4m3 + 1 padding static constexpr int NUM_THREADS = 128*3; static constexpr int BLOCK_M = 64; static constexpr int TOPK_BLOCK_SIZE = 64; static constexpr int NUM_K_BUFS = 2; using SmemLayoutQTile = decltype(tile_to_shape( GMMA::Layout_SW128_Atom{}, Shape, Int<64>>{} )); template using SmemLayoutQTiles = decltype(tile_to_shape( SmemLayoutQTile{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} )); using SmemLayoutQ = SmemLayoutQTiles; using SmemLayoutKTile = decltype(tile_to_shape( GMMA::Layout_INTER_Atom{}, Shape, _64>{}, Step<_1, _2>{} )); template using SmemLayoutKTiles = decltype(tile_to_shape( SmemLayoutKTile{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} )); template using SmemLayoutKTilesTransposed = decltype(composition( SmemLayoutKTiles{}, Layout, Int>, Stride, _1>>{} )); static constexpr int OBUF_SW = 64; using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom; using SmemLayoutOBuf = decltype(tile_to_shape( SmemLayoutOBufAtom{}, Shape, Int>{}, Step<_1, _2>{} )); using SmemLayoutOAccumBuf = Layout< Shape, Int>, Stride, _1> // We use stride = 520 here to avoid bank conflict >; using SmemLayoutK = SmemLayoutKTiles; using SmemLayoutV = SmemLayoutKTilesTransposed; using SmemLayoutHalfV = SmemLayoutKTilesTransposed; using SmemLayoutS = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); struct SharedMemoryPlan { // array_aligned> q; // union { // array_aligned> k[NUM_K_BUFS]; // array_aligned> oBuf; // array_aligned> oAccumBuf; // } u; // CUTE_ALIGNAS(1024) array_aligned> s; // bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; // float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M]; // transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; }; // template< // typename Shape_Q, typename TMA_Q // > // using TiledMMA_QK = decltype(make_tiled_mma( // GMMA::MMA_64x64x16_F32BF16BF16_SS{}, // Layout>{} // )); // using TiledMMA_QK_rQ = decltype(make_tiled_mma( // GMMA::MMA_64x64x16_F32BF16BF16_RS{}, // Layout>{} // )); // using TiledMMA_PV_LocalP = decltype(make_tiled_mma( // GMMA::MMA_64x256x16_F32BF16BF16_RS{}, // Layout>{} // )); // using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( // GMMA::MMA_64x256x16_F32BF16BF16_SS{}, // Layout>{} // )); static __device__ __forceinline__ void devfunc(const SparseAttnDecodeParams ¶ms); static void run(const SparseAttnDecodeParams ¶ms); }; }