// NOTE(review): This text cannot compile as shown and should be restored from
// version control before any further change. The notes below describe what is
// visibly wrong, so the restore can be checked against them.
//
//   1. Most template argument lists appear to have been stripped. The ones
//      that survived, such as `Int<64>`, `Stride<_1>` and `Stride<_64, _1>`,
//      start with a digit or an underscore. Examples of what is missing:
//        - `template struct Traits` has no parameter list;
//        - `static_assert(std::is_same_v);` has no operands;
//        - `Layout, Stride<_1>>`, `Swizzle{}`, `MMA_Atom;`, `Layout>;` and
//          `cute::array_aligned>` have lost their arguments.
//      The four `#include` directives before "config.h" have also lost their
//      targets.
//   2. Everything on the first line follows `#pragma once`. A preprocessing
//      directive runs to the end of its line, and `#` only introduces a
//      directive at the start of a line. So the `#include`s and the struct
//      text on that line are part of the `#pragma once` line, not compiled
//      code.
//   3. On the second line, `// using TiledMma_O = ...` is a line comment that
//      runs to the end of that line. As written, it also comments out
//      `struct SharedMemoryPlan { ... };` and the closing `};` of Traits.
//      Presumably the commented-out TiledMma_O declaration originally ended
//      on its own line.
//   4. NUM_THREADS is 256 but kNWarps is 4 (4 * 32 = 128 threads). The
//      TiledMma thread layout is commented "1x4x1 or 1x8x1 thread group", but
//      its arguments were stripped, so whether it is sized by kNWarps cannot
//      be checked. Confirm the mismatch is intentional.
//   5. `using namespace cute;` at header scope pulls all of namespace cute
//      into every file that includes this header. Prefer qualified names
//      (cute::Layout, cute::Shape, ...).
//
// Traits: compile-time configuration bundle. Its visible members refer to
// InputT_ (the input element type) and Is_causal_ (a bool flag). The Q/K/V/P
// naming suggests an attention kernel (assumption, confirm against users).
//   - Tile sizes are copied from Config: BLOCK_SIZE_M, PAGE_BLOCK_SIZE,
//     HEAD_DIM_K and HEAD_DIM_V. They are re-exported as kBlockM, kBlockN,
//     kHeadDim and kHeadDimV; kBlockN equals PAGE_BLOCK_SIZE.
//   - Element and elem_type alias InputT.
//   - ElementAccum is float, presumably the MMA accumulator type; confirm at
//     the use sites.
//   - Shared-memory layouts:
//       - SmemLayoutAtomQ is a plain Layout with no Swizzle.
//       - The K atom, reused as the V atom, is composed with a Swizzle.
//         kSwizzle = 3 is declared, but the Swizzle arguments were stripped.
//       - SmemLayoutK tiles to Int<8 * 64> columns.
//       - SmemLayoutVtransposed composes SmemLayoutV with a GenRowMajor
//         layout. This is presumably a transposed view of V; its shape
//         arguments were stripped.
//       - SmemLayoutVtransposedNoSwizzle is get_nonswizzle_portion of that
//         layout.
//       - SmemLayoutK_place_holder tiles an unswizzled atom with Stride<_64, _1>
//         to Int<7*64> columns. Its purpose is not visible here.
//   - TiledMma is a TiledMMA over MMA_Atom_Arch with ValLayoutMNK. A
//     TiledMma_O variant over MMA_Atom_Arch_16x32 is commented out.
//   - SharedMemoryPlan is a union of three anonymous structs, so these three
//     groups share the same storage and only one can hold live data at a
//     time:
//       - {smem_v}
//       - {smem_temp, smem_p, smem_row_sum, smem_row_max}
//       - {smem_q}
//     The "// Double buffer" notes on smem_v and smem_temp cannot be checked,
//     because the array_aligned sizes were stripped.
#pragma once #include #include #include #include #include "config.h" using namespace cute; template struct Traits { using InputT = InputT_; static constexpr bool Is_causal = Is_causal_; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v); static constexpr int kBlockM = BLOCK_SIZE_M; static constexpr int kBlockN = PAGE_BLOCK_SIZE; static constexpr int kHeadDim = HEAD_DIM_K; static constexpr int kHeadDimV = HEAD_DIM_V; static constexpr int kNWarps = 4; using Element = InputT; using elem_type = Element; using ElementAccum = float; using SmemLayoutRow = Layout, Stride<_1>>; static constexpr int kSwizzle = 3; using SmemLayoutAtomQ = Layout, Int<64>>, Stride, _1>>; using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<8 * 64>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomK_place_holder = Layout, Int<64>>, Stride<_64, _1>>; using SmemLayoutK_place_holder = decltype(tile_to_shape( SmemLayoutAtomK_place_holder{}, Shape, Int<7*64>>{})); using MMA_Atom_Arch = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group 
ValLayoutMNK>;// using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; struct SharedMemoryPlan { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_q; }; }; }; };