#pragma once
// NOTE(review): this header arrived with every template-argument list and
// every #include target stripped (e.g. `template struct Traits`,
// `std::is_same_v)`, `Layout>`, `MMA_Atom, MMA_Atom`,
// `cute::array_aligned>`). The tokens below are reproduced exactly as
// received. The file cannot compile until the missing `<...>` contents are
// restored from version control. Comments marked "presumably" describe
// intent that the stripped arguments no longer show.

// NOTE(review): the four bare `#include` directives lost their header names.
// The names used below (cute::, Swizzle, TiledMMA, MMA_Atom, Copy_Atom,
// std::is_same_v, std::conditional_t) suggest CuTe/CUTLASS headers and
// <type_traits>. TODO restore.
#include
#include
#include
#include
#include "config.h"

// NOTE(review): a using-directive at namespace scope in a header injects all
// of `cute` into every translation unit that includes this file.
using namespace cute;

// Compile-time configuration bundle: tile sizes, CuTe shared-memory layouts,
// tiled-MMA types, a tiled copy for Q, and the shared-storage plan.
// NOTE(review): the template parameter list was stripped. The body uses
// `InputT_` (a type), `Is_causal_` (a bool) and `Config` (a type exposing
// BLOCK_SIZE_M, PAGE_BLOCK_SIZE, HEAD_DIM_K, HEAD_DIM_V).
// TODO confirm the exact names and order against the callers.
template struct Traits {
    // ---- Parameters re-exported from the template arguments / Config ----
    using InputT = InputT_;
    static constexpr bool Is_causal = Is_causal_;

    static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
    static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
    static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
    static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;

    // NOTE(review): with 32-thread warps, 256 threads is 8 warps, but kNWarps
    // below is 4. The TiledMMA comments mention both 1x4x1 and 1x8x1 thread
    // groups. TODO confirm which count the launch and the tiled MMAs rely on.
    static constexpr int NUM_THREADS = 256;

    // NOTE(review): the template arguments were stripped. This presumably
    // restricts InputT to one supported element type. TODO restore.
    static_assert(std::is_same_v);

    // Short aliases for the sizes above. The N tile is one page block.
    static constexpr int kBlockM = BLOCK_SIZE_M;
    static constexpr int kBlockN = PAGE_BLOCK_SIZE;
    static constexpr int kHeadDim = HEAD_DIM_K;
    static constexpr int kHeadDimV = HEAD_DIM_V;
    static constexpr int kNWarps = 4;
    // NOTE(review): SmemLayoutAtomK hard-codes Swizzle<3, 3, 3> instead of
    // reading this constant, so the two must be kept in sync by hand.
    static constexpr int kSwizzle = 3;

    using Element = InputT;
    using elem_type = Element;
    using ElementAccum = float;  // presumably the MMA accumulator type

    // NOTE(review): shape arguments stripped. TODO restore.
    using ValLayoutMNK = Layout>;
    // Single-mode layout with unit stride (extent stripped).
    using SmemLayoutRow = Layout, Stride<_1>>;

    // ---- Shared-memory layouts for K and V ----
    // K atom: Swizzle<3, 3, 3> composed with a layout whose second mode has
    // extent 32 and stride _1, i.e. row-major. The first-mode extent and
    // stride were stripped.
    using SmemLayoutAtomK = decltype(composition(
        Swizzle<3, 3, 3>{},
        Layout, Int<32>>,
               Stride, _1>>{}));
    // K tile: the atom repeated to a shape whose second mode is 8 * 32
    // (first mode stripped).
    using SmemLayoutK = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape, Int<8 * 32>>{}));

    // V reuses K's swizzled atom and the same (partially stripped) tile shape.
    using SmemLayoutAtomV = SmemLayoutAtomK;
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int<8 * 32>>{}));

    // P layout. Atom and shape arguments were stripped.
    using SmemLayoutAtomP = Layout>, Stride>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape>{}));

    // Per the name, a transposed view of V: SmemLayoutV composed with a
    // GenRowMajor layout (shape arguments stripped).
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    // The same view with the swizzle portion removed.
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    // ---- _fp8 variants of the V layouts ----
    // Plain (unswizzled) row-major atom, 512 elements wide: Stride<_512, _1>.
    using SmemLayoutAtomV_fp8 = Layout, Int<512>>, Stride<_512, _1>>;
    using SmemLayoutV_fp8 = decltype(tile_to_shape(
        SmemLayoutAtomV_fp8{},
        Shape, Int<512>>{}));
    using SmemLayoutVtransposed_fp8 = decltype(
        composition(SmemLayoutV_fp8{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle_fp8 = decltype(get_nonswizzle_portion(SmemLayoutVtransposed_fp8{}));

    // ---- Shared-memory layout for Q ----
    // Q atom: a swizzle (arguments stripped; presumably built from kSwizzle)
    // over a layout whose second mode has extent 64 and stride _1.
    using SmemLayoutAtomQ = decltype(composition(
        Swizzle{},
        Layout, Int<64>>,
               Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));

    // ---- Tiled MMAs ----
    // Each MMA_Atom_Arch* alias uses std::conditional_t on a type test to pick
    // one of two MMA atoms. The tested types and atom names were stripped
    // (presumably the test is on InputT). The numeric suffixes presumably name
    // the MMA tile shape. TODO restore.
    using MMA_Atom_Arch_16_16_32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma_16_16_32 = TiledMMA<
        MMA_Atom_Arch_16_16_32,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;

    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;

    using MMA_Atom_Arch_16_32_16 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma_O_16_32_16 = TiledMMA<
        MMA_Atom_Arch_16_32_16,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;

    using TiledMma = TiledMMA<
        MMA_Atom_Arch,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;

    using MMA_Atom_Arch_int8 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;

    using MMA_Atom_Arch_16x64 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma_O = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;

    using TiledMma_int8 = TiledMMA<
        MMA_Atom_Arch_int8,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;

    // ---- Tiled copy for Q ----
    // Thread layout for the Q copy, with strides (_8, _1). This is presumably
    // row-major with 8 threads per row; the shape was stripped.
    using GmemLayoutAtomQ = Layout,
                                   Stride< _8, _1>>;
    // The Copy_Atom arguments and the per-thread value layout were stripped.
    using GmemTiledCopyQ = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomQ{},
                        Layout>{}));

    // Storage plan for the kernel's buffers. By the smem_ naming it is
    // presumably placed in CUDA shared memory by the kernel (not shown here).
    // The two anonymous structs share storage through the union, so smem_q
    // overlaps smem_v / smem_p / smem_row_sum / smem_row_max. The kernel must
    // not keep Q live while writing any of those buffers, and vice versa.
    // NOTE(review): anonymous structs inside a union are a compiler extension
    // (accepted by GCC/Clang/NVCC), not standard C++.
    // NOTE(review): all array_aligned element types and sizes were stripped.
    struct SharedMemoryPlan {
        union {
            struct {
                cute::array_aligned> smem_v;  // Double buffer (per original comment; size stripped)
                cute::array_aligned> smem_p;
                cute::array_aligned> smem_row_sum;
                cute::array_aligned> smem_row_max;
            };
            struct {
                cute::array_aligned> smem_q;
            };
        };
    };
};