// Compile-time kernel traits built on CuTe (`using namespace cute;` and CuTe
// names such as Layout, Swizzle, composition, tile_to_shape, TiledMMA and
// make_tiled_copy are used below). Presumably these configure an attention
// kernel (Is_causal flag, paged K/V block size, row max / row sum buffers) —
// TODO confirm against the kernel that instantiates them.
//
// Traits<...>
//   - Reads BLOCK_SIZE_M, PAGE_BLOCK_SIZE, HEAD_DIM_K and HEAD_DIM_V from a
//     `Config` type (presumably declared in config.h — confirm).
//   - Re-exports those values as kBlockM / kBlockN / kHeadDim / kHeadDimV.
//   - Fixed values: NUM_THREADS = 256, kNWarps = 4, ElementAccum = float.
//     Element and elem_type alias InputT.
//   - A static_assert restricts InputT to one of two element types.
//   - Declares layouts named SmemLayout* for Q, K, V and P, plus a transposed
//     view of V and its non-swizzled portion. The K, V and Q atoms are
//     composed with Swizzle<3, 3, 3>; the P atom is a plain Layout.
//   - Declares the MMA configurations TiledMma and TiledMma_O. Each selects
//     its MMA_Atom via std::conditional_t on the input type.
//   - Declares the copy types GmemTiledCopyQ and GmemTiledCopyK, built with
//     make_tiled_copy. GmemTiledCopyV simply aliases GmemTiledCopyK.
//   - SharedMemoryPlan is a union of three anonymous structs:
//     { smem_v }, { smem_temp, smem_p, smem_row_sum, smem_row_max } and
//     { smem_q }. The three groups overlay the same storage, so they cannot be
//     used at the same time.
//   - The alternative MMA_Atom_Arch / TiledMma definitions for __gfx928__ are
//     retained only as commented-out code; the __gfx936__ / __gfx938__ #if
//     guard is also commented out.
//   - NOTE(review): NUM_THREADS = 256 with kNWarps = 4 means 64 threads per
//     warp, which matches 64-wide wavefronts on the gfx* targets named in
//     the comments. Confirm this is intentional and not a 32-wide assumption.
//
// Traits_Block_M_64<...>
//   - Standalone variant that hard-codes BLOCK_SIZE_M = 64,
//     PAGE_BLOCK_SIZE = 64, HEAD_DIM_K = 576 and HEAD_DIM_V = 512 instead of
//     reading them from a Config.
//   - Declares only the scalar constants and element types. It has no
//     layouts, MMA, copy or shared-memory definitions.
//
// NOTE(review): this copy of the header appears damaged.
//   - Every template argument list (the contents of <...>) has been stripped.
//   - The original line breaks have been collapsed.
//   As a result the two lines below are not valid C++ as written. Everything
//   after `#pragma once` on the first line is part of that directive, and the
//   remainder of each line after its first `//` is swallowed by the comment.
//   Restore this file from the upstream source before building; do not try
//   to repair it by hand.
#pragma once #include #include #include #include #include "config.h" using namespace cute; template struct Traits { using InputT = InputT_; static constexpr bool Is_causal = Is_causal_; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v || std::is_same_v); static constexpr int kBlockM = BLOCK_SIZE_M; static constexpr int kBlockN = PAGE_BLOCK_SIZE; static constexpr int kHeadDim = HEAD_DIM_K; static constexpr int kHeadDimV = HEAD_DIM_V; static constexpr int kNWarps = 4; using Element = InputT; using elem_type = Element; using ElementAccum = float; using SmemLayoutRow = Layout, Stride<_1>>; using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<16 * 32>>{})); using SmemLayoutK_place_holder = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<15 * 32>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomQ = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using ValLayoutMNK = Layout>; // #if defined(__gfx936__) || defined(__gfx938__) using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 
1x8x1 thread group ValLayoutMNK>; // #elif defined(__gfx928__) // using MMA_Atom_Arch = std::conditional_t< // std::is_same_v, // MMA_Atom, // MMA_Atom // >; // using TiledMma = TiledMMA< // MMA_Atom_Arch, // Layout, _1>>, // 1x4x1 or 1x8x1 thread group // ValLayoutMNK>; // #endif using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using GmemLayoutAtomQ = Layout, Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomK = Layout, Stride< _4, _1>>; using GmemTiledCopyK = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomK{}, Layout>{})); using GmemTiledCopyV = GmemTiledCopyK; struct SharedMemoryPlan { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_q; }; }; }; }; template struct Traits_Block_M_64 { using InputT = InputT_; static constexpr bool Is_causal = Is_causal_; static constexpr int BLOCK_SIZE_M = 64; static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v || std::is_same_v); static constexpr int kBlockM = BLOCK_SIZE_M; static constexpr int kBlockN = PAGE_BLOCK_SIZE; static constexpr int kHeadDim = HEAD_DIM_K; static constexpr int kHeadDimV = HEAD_DIM_V; static constexpr int kNWarps = 4; using Element = InputT; using elem_type = Element; using ElementAccum = float; };