#pragma once #include #include #include #include #include "defines.h" #include "params.h" using namespace cute; namespace sm90::decode::sparse_fp8 { template class KernelTemplate { public: static_assert(NUM_HEADS == 64 || NUM_HEADS == 128 || NUM_HEADS == 16); // todo only support tp8 static constexpr int BLOCK_M = 16; static constexpr int NUM_M_BLOCKS = NUM_HEADS / BLOCK_M; static constexpr bool Is_causal = false; static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512; static constexpr int HEAD_DIM_V = 512; static constexpr int HEAD_DIM_ROPE = 64; static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE; static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64; static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8; // For MODEL1: 7 fp8_e4m3 + 1 padding static constexpr int NUM_THREADS = 256; static constexpr int TOPK_BLOCK_SIZE = 64; using elem_type = cutlass::bfloat16_t; using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; static constexpr int kNWarps = 4; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using MMA_Atom_Arch_16_16_32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_16_16_32 = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using MMA_Atom_Arch_16x32_NT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32_NT, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<8 * 32>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<512>>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutRow = Layout, Stride<_1>>; using Element = cutlass::bfloat16_t; using ElementAccum = float; struct SharedMemoryPlan { union { struct { cute::array_aligned> smem_v; }; struct { // cute::array_aligned> smem_v_tmp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; // struct { // cute::array_aligned> smem_o; // // cute::array_aligned> smem_p; // // cute::array_aligned> smem_row_sum; // // cute::array_aligned> smem_row_max; // }; // struct { // cute::array_aligned> smem_q; // }; }; // array_aligned> q; // union { // array_aligned> k[NUM_K_BUFS]; // array_aligned> oBuf; // array_aligned> oAccumBuf; // } u; // CUTE_ALIGNAS(1024) array_aligned> s; // bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; // float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M]; // transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; }; // template< // typename Shape_Q, typename TMA_Q // > // using TiledMMA_QK = decltype(make_tiled_mma( // GMMA::MMA_64x64x16_F32BF16BF16_SS{}, // Layout>{} // )); // using TiledMMA_QK_rQ = decltype(make_tiled_mma( // GMMA::MMA_64x64x16_F32BF16BF16_RS{}, // Layout>{} // )); // using TiledMMA_PV_LocalP = decltype(make_tiled_mma( // GMMA::MMA_64x256x16_F32BF16BF16_RS{}, // Layout>{} // )); // using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( // GMMA::MMA_64x256x16_F32BF16BF16_SS{}, // Layout>{} // )); static __device__ __forceinline__ void compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams ¶ms, const DecodingSchedMeta& sched_meta, int batch_idx); static __device__ __forceinline__ void devfunc(const SparseAttnDecodeParams ¶ms); static void run(const SparseAttnDecodeParams ¶ms); }; }