// Compile-time traits for a kernel, parameterised on an input element type
// (InputT_) and a causal flag (Is_causal_).
//
// What the struct exposes:
//   * Tile sizes copied from `Config` (declared in "config.h"):
//     BLOCK_SIZE_M, PAGE_BLOCK_SIZE, HEAD_DIM_K, HEAD_DIM_V.
//     They are re-exported under the short names kBlockM, kBlockN, kHeadDim
//     and kHeadDimV.
//   * NUM_THREADS = 256 and kNWarps = 4.
//   * Element / elem_type: aliases of InputT.
//   * ElementAccum: float.
//   * SmemLayoutRow: a CuTe Layout whose stride is _1, i.e. contiguous.
//   * SmemLayoutAtomK: a Swizzle composed with a layout.
//     - kSwizzle = 3 is declared alongside it.
//     - The layout's inner extent is 64 with unit (_1) innermost stride.
//   * SmemLayoutK: that atom tiled (tile_to_shape) to a shape whose second
//     extent is 8 * 64 = 512.
//   * SharedMemoryPlan: currently an empty struct.
//
// NOTE(review): this file appears corrupted and will not build as intended.
//   1. The entire header sits on ONE physical line. The preprocessor therefore
//      treats everything after `#pragma once` as extra tokens of that pragma.
//      Most compilers warn and ignore them, so none of the includes or
//      declarations below are actually seen.
//   2. Every `<identifier...>` span appears to have been stripped. Spans that
//      start with a digit or '_' survived (`<64>`, `<_1>`, `<8 * 64>`).
//      Missing pieces include:
//      - the five `#include <...>` targets;
//      - the `template<...>` parameter list, presumably
//        `typename InputT_, bool Is_causal_` judging by the member aliases
//        (TODO confirm);
//      - the arguments of `std::is_same_v`;
//      - the `Swizzle<...>` arguments;
//      - the leading `Shape<...` arguments of SmemLayoutRow, SmemLayoutAtomK's
//        inner Layout and Stride, and SmemLayoutK.
//      Restore these from version control / upstream. Do not guess: the
//      stripped values define shared-memory layouts.
// NOTE(review): kNWarps = 4 implies 4 * 32 = 128 threads, but
//   NUM_THREADS = 256. Verify which value the kernel actually relies on.
// NOTE(review): `using namespace cute;` at header scope leaks the namespace
//   into every includer. Consider qualifying names or scoping the directive.
#pragma once #include #include #include #include #include "config.h" using namespace cute; template struct Traits { using InputT = InputT_; static constexpr bool Is_causal = Is_causal_; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v); static constexpr int kBlockM = BLOCK_SIZE_M; static constexpr int kBlockN = PAGE_BLOCK_SIZE; static constexpr int kHeadDim = HEAD_DIM_K; static constexpr int kHeadDimV = HEAD_DIM_V; static constexpr int kNWarps = 4; using Element = InputT; using elem_type = Element; using ElementAccum = float; using SmemLayoutRow = Layout, Stride<_1>>; static constexpr int kSwizzle = 3; using SmemLayoutAtomK = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<8 * 64>>{})); struct SharedMemoryPlan { }; };