/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include
#include "cutlass/float8.h"
using namespace cute;

// Base traits shared by the forward-kernel traits in this header. Depending on the build target
// (__CUDA_ARCH__ >= 800, DCU_ASM, or anything else) it selects the element type, whether
// cp.async is available, the MMA atoms and the shared-memory copy atoms.
template struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#elif defined(DCU_ASM)
    // DCU path: keeps the requested element type but reports no cp.async support, so the
    // derived traits fall back to DefaultCopy for global->shared copies.
    using Element = elem_type;
    static constexpr bool Has_cp_async = false;
#else
    // Fallback path: the element type is forced to cutlass::half_t.
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
#else
    // NOTE(review): MMA_Atom_Arch_FOR_GEMM1 and ValLayoutMNK are declared only in this #else
    // branch, but the derived traits use Base::ValLayoutMNK (and, unless GEMM1_AMATRIX_WITH_SMEM
    // is defined, Base::MMA_Atom_Arch_FOR_GEMM1). A build with __CUDA_ARCH__ >= 800 would
    // therefore presumably fail to compile — confirm whether that configuration is still supported.
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    // MMA atom used by TiledMma_FOR_GEMM1 in the derived traits (the second GEMM).
    using MMA_Atom_Arch_FOR_GEMM1 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using ValLayoutMNK = Layout>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom;
    using SmemCopyAtomTransposed = Copy_Atom;
#else
    using SmemCopyAtom = Copy_Atom;
    using SmemCopyAtomTransposed = Copy_Atom;
#endif
};

// Generic forward-kernel traits built on top of Base. Defines the block sizes, the tiled MMAs,
// the shared-memory layouts and sizes, and the global-memory tiled copies.
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    // Each warp is counted as 64 threads (presumably a 64-lane DCU wavefront — confirm).
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // Widest of 64/32 (smem) or 128/64/32 (gmem) that divides kHeadDim.
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
#ifndef GEMM1_AMATRIX_WITH_SMEM
    using TiledMma_FOR_GEMM1 = TiledMMA<
        typename Base::MMA_Atom_Arch_FOR_GEMM1,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
#endif
    using SmemLayoutAtomQ = decltype(
        composition(Swizzle{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout>,
                           Stride, _1>>{}));
    // using SmemLayoutAtomQ = decltype(
    //     composition(Swizzle<3, 2, 4>{},
    //                 // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
    //                 Layout>,
    //                        Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
#ifdef GEMM1_AMATRIX_WITH_SMEM
    using SmemLayoutAtomAccs = decltype(
        composition(Swizzle<3, 2, 4>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutAccs = decltype(tile_to_shape(
        SmemLayoutAtomAccs{},
        Shape, Int>{}));
#endif

    // K and V tiles reuse the Q layout atom.
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
#ifdef GEMM1_AMATRIX_WITH_SMEM
    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
#else
    // For headdim 128, Swizzle<2, 4, 2> also effectively reduces bank conflicts, but it showed no
    // measured performance gain, so two parameter sets are kept.
    // using SmemLayoutAtomV = decltype(
    //     composition(Swizzle{},
    //                 // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
    //                 Layout>,
    //                        Stride, _1>>{}));
    using SmemLayoutAtomV = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
#endif

    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom;
    using SmemCopyAtomOaccum = Copy_Atom;

    // Shared-memory footprints in bytes.
    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
#ifdef GEMM1_AMATRIX_WITH_SMEM
    static constexpr int KSmemAccsSize = size(SmemLayoutAccs{}) * sizeof(Element);
#endif
    // Two SmemLayoutKV-sized tiles (K and V).
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);
#ifdef GEMM1_AMATRIX_WITH_SMEM
    // For example, with headdim 32, kSmemQSize is larger than KSmemAccsSize.
    // NOTE(review): when Share_Q_K_smem is false, KSmemAccsSize is not included in the total —
    // confirm the Accs buffer is not used on that path.
    static constexpr int kSmemSize = Share_Q_K_smem ? (kSmemQSize > KSmemAccsSize ? std::max(kSmemQSize, kSmemKVSize) : KSmemAccsSize + kSmemKSize) : kSmemQSize + kSmemKVSize;
#else
    // With Share_Q_K_smem only the larger of the Q and K/V regions is reserved; otherwise both are.
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
#endif

    // Number of elements moved by one 128-bit global load.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    // With 512 threads the row width is forced to 16 threads regardless of kBlockKSmem.
    static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    // Number of rows each thread has to fetch.
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
    // mainloop iteration. Rudimentary testing shows no slowdown.
    using GmemTiledCopyQKVPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout, _8>, Stride<_8, _1>>{}));
#if 0
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
#else
    using GmemLayoutAtomO = Layout,
                                   Stride< _4, _1>>;
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomO{},
                        Layout>{}));  // Val layout, 8 vals per store
#endif
#if 0
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
#else
    // NOTE(review): the row stride below is _4, which suggests 4 threads per row rather than the
    // 8 stated in the trailing comment — confirm against the thread-layout shape.
    using GmemLayoutAtomOaccum = Layout,  // Thread layout, 8 threads per row
                                        Stride< _4, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
#endif
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 8 vals per load
    using GmemTiledCopyRotcossinPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinContPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// Kernel traits for attnmask (uses 16x64x32 MMA for Q*K and specialized SmemCopyAtomV).
////////////////////////////////////////////////////////////////////////////////////////////////////

// Unlike Flash_fwd_kernel_traits, the Q*K MMA atom and the shared-memory copy atoms are defined
// locally here instead of being taken from Base; P*V (GEMM1) uses Base::MMA_Atom_Arch.
template >
struct Flash_fwd_kernel_traits_attnmask : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = Copy_Atom, Element>;
    using SmemCopyAtomTransposed = Copy_Atom;
    using SmemCopyAtomV = Copy_Atom;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // Each warp is counted as 64 threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // kBlockKSmem is fixed to 32 here, so kSwizzle always evaluates to 2.
    static constexpr int kBlockKSmem = 32;
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    // Q*K GEMM: use 16x64x32 MMA (same as predmaskbeta bias_hdim32)
    using MMA_Atom_Arch_16x64x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        MMA_Atom_Arch_16x64x32,
        Layout,_1,_1>>,
        typename Base::ValLayoutMNK>;

    // P*V GEMM: use standard 16x16x16 MMA with ValLayout <1,2,1>
    using TiledMma_FOR_GEMM1 = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,
        Layout>>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<3, 3, 3>{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
    // K and V tiles reuse the Q layout atom.
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));

    // V uses an unswizzled layout atom.
    using SmemLayoutAtomV = Layout>,
                                   Stride, _1>>;
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom;
    using SmemCopyAtomOaccum = Copy_Atom;

    // Shared-memory footprints in bytes; with Share_Q_K_smem only the larger of Q and K/V is
    // reserved, otherwise both are.
    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;

    // Number of elements moved by one 128-bit global load.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout>{}));
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);

    using GmemLayoutAtomO = Layout,
                                   Stride< _4, _1>>;
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomO{},
                        Layout>{}));
    using GmemLayoutAtomOaccum = Layout,
                                        Stride< _4, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));
};

// Forward traits for the 16x64 MMA path. Is_Q_use_smem selects the 16x64x16 MMA atom, otherwise
// the 16x64x32 atom is used; Share_K_V_smem makes K and V share a single K-sized buffer.
template >
struct Flash_fwd_kernel_16x64_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_K_V_smem = Share_K_V_smem_;
    static constexpr bool Is_Q_use_smem = Is_Q_use_smem_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr bool MMA_Atom_Use_K16 = Is_Q_use_smem;
    static constexpr bool MMA_Atom_Use_K32 = !MMA_Atom_Use_K16;

    // Both aliases resolve to Base::SmemCopyAtom, so SmemCopyAtom below is the same type
    // regardless of MMA_Atom_Use_K16.
    using SmemCopyAtom16x64x16 = typename Base::SmemCopyAtom;
    // using SmemCopyAtom16x64x32 = Copy_Atom;
    using SmemCopyAtom16x64x32 = typename Base::SmemCopyAtom;

    using MMA_Atom_Arch_16x64x16 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using MMA_Atom_Arch_16x64x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using SmemCopyAtom = std::conditional_t<
        MMA_Atom_Use_K16,
        SmemCopyAtom16x64x16,
        SmemCopyAtom16x64x32
    >;
    using MMA_Atom_Arch_16x64 = std::conditional_t<
        MMA_Atom_Use_K16,
        MMA_Atom_Arch_16x64x16,
        MMA_Atom_Arch_16x64x32
    >;
    using MMA_Atom_Arch_16x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x64 = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x32 = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
    using SmemLayoutAtomK = decltype(
        composition(Swizzle<3, 3, 4>{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape,
              Int>{}));
    // using SmemLayoutAtomV = Layout, Int>,
    //                               Stride, _1>>;
    using SmemLayoutAtomV = decltype(
        composition(Swizzle<1, 3, 3>{},
                    Layout, Int<32>>,
                           Stride, _1>>{}));
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom;
    using SmemCopyAtomOaccum = Copy_Atom;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);
    static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
    // With Share_K_V_smem a single K-sized buffer serves both K and V.
    static constexpr int kSmemKVSize = Share_K_V_smem ? kSmemKSize : 2 * kSmemKSize;
    // Fixed 32 KiB for kHeadDim == 64; otherwise the larger of Q and K/V when Is_Q_use_smem,
    // or just K/V.
    // NOTE(review): the 32 KiB constant is not checked against kSmemKVSize / kSmemOSize — confirm it
    // covers every kBlockN used with d=64.
    static constexpr int kSmemSize = kHeadDim == 64 ? 32 * 1024 : Is_Q_use_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
#if 1
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;
#else
    using GmemLayoutAtom = Layout,
                                  Stride< _4, _1>>;
#endif

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    // Number of rows each thread has to fetch.
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
    // mainloop iteration. Rudimentary testing shows no slowdown.
    using GmemTiledCopyQKVPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout, _8>, Stride<_8, _1>>{}));
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 8 vals per load
    using GmemTiledCopyRotcossinPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinContPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
};

// MLA variant of the 16x64 traits: same structure as Flash_fwd_kernel_16x64_traits, but K uses
// Swizzle<3, 3, 5>, V uses an unswizzled layout atom, and kSmemSize is a fixed 32768 bytes.
template >
struct Flash_fwd_kernel_16x64_traits_MLA : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_K_V_smem = Share_K_V_smem_;
    static constexpr bool Is_Q_use_smem = Is_Q_use_smem_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr bool MMA_Atom_Use_K16 = Is_Q_use_smem;
    static constexpr bool MMA_Atom_Use_K32 = !MMA_Atom_Use_K16;

    // Both aliases resolve to Base::SmemCopyAtom.
    using SmemCopyAtom16x64x16 = typename Base::SmemCopyAtom;
    // using SmemCopyAtom16x64x32 = Copy_Atom;
    using SmemCopyAtom16x64x32 = typename Base::SmemCopyAtom;

    using MMA_Atom_Arch_16x64x16 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using MMA_Atom_Arch_16x64x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using SmemCopyAtom = std::conditional_t<
        MMA_Atom_Use_K16,
        SmemCopyAtom16x64x16,
        SmemCopyAtom16x64x32
    >;
    using MMA_Atom_Arch_16x64 = std::conditional_t<
        MMA_Atom_Use_K16,
        MMA_Atom_Arch_16x64x16,
        MMA_Atom_Arch_16x64x32
    >;
    using MMA_Atom_Arch_16x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x64 = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x32 = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
    using SmemLayoutAtomK = decltype(
        composition(Swizzle<3, 3, 5>{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape,
              Int>{}));
    using SmemLayoutAtomV = Layout, Int<32>>,
                                   Stride, _1>>;
    // using SmemLayoutAtomV = decltype(
    //     composition(Swizzle<1, 3, 3>{},
    //                 Layout, Int<32>>,
    //                        Stride, _1>>{}));
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom;
    using SmemCopyAtomOaccum = Copy_Atom;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);
    static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
    static constexpr int kSmemKVSize = Share_K_V_smem ? kSmemKSize : 2 * kSmemKSize;
    // Writing the output through LDS (shared memory) improved performance; 128*128*2 = 32768.
    // NOTE(review): this constant does not depend on kSmemQSize / kSmemKVSize / kSmemOSize —
    // confirm it is large enough for every instantiation.
    static constexpr int kSmemSize = 32768;
    // static constexpr int kSmemSize = Is_Q_use_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
#if 1
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;
#else
    using GmemLayoutAtom = Layout,
                                  Stride< _4, _1>>;
#endif

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    // Number of rows each thread has to fetch.
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
    // mainloop iteration. Rudimentary testing shows no slowdown.
    using GmemTiledCopyQKVPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout, _8>, Stride<_8, _1>>{}));
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 8 vals per load
    using GmemTiledCopyRotcossinPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinContPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
};

// Split-KV variant of the 16x64 traits; same layouts as Flash_fwd_kernel_16x64_traits_MLA.
template >
struct Flash_fwd_kernel_16x64_traits_splitkv : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_K_V_smem = Share_K_V_smem_;
    static constexpr bool Is_Q_use_smem = Is_Q_use_smem_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr bool MMA_Atom_Use_K16 = Is_Q_use_smem;
    static constexpr bool MMA_Atom_Use_K32 = !MMA_Atom_Use_K16;

    // Both aliases resolve to Base::SmemCopyAtom.
    using SmemCopyAtom16x64x16 = typename Base::SmemCopyAtom;
    // using SmemCopyAtom16x64x32 = Copy_Atom;
    using SmemCopyAtom16x64x32 = typename Base::SmemCopyAtom;

    using MMA_Atom_Arch_16x64x16 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using MMA_Atom_Arch_16x64x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using SmemCopyAtom = std::conditional_t<
        MMA_Atom_Use_K16,
        SmemCopyAtom16x64x16,
        SmemCopyAtom16x64x32
    >;
    using MMA_Atom_Arch_16x64 = std::conditional_t<
        MMA_Atom_Use_K16,
        MMA_Atom_Arch_16x64x16,
        MMA_Atom_Arch_16x64x32
    >;
    using MMA_Atom_Arch_16x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x64 = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x32 = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));
    using SmemLayoutAtomK = decltype(
        composition(Swizzle<3, 3, 5>{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape,
              Int>{}));
    using SmemLayoutAtomV = Layout, Int<32>>,
                                   Stride, _1>>;
    // using SmemLayoutAtomV = decltype(
    //     composition(Swizzle<1, 3, 3>{},
    //                 Layout, Int<32>>,
    //                        Stride, _1>>{}));
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom;
    using SmemCopyAtomOaccum = Copy_Atom;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);
    static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
    static constexpr int kSmemKVSize = Share_K_V_smem ? kSmemKSize : 2 * kSmemKSize;
    // Writing the output through LDS (shared memory) improved performance; 128*128*2 = 32768.
    // NOTE(review): both branches of this conditional are 32768, so kSmemSize is effectively a
    // constant and does not track the computed Q/KV/O sizes — confirm this is intended.
    static constexpr int kSmemSize = kHeadDim == kHeadDimV ? 32768 : 32768;
    // static constexpr int kSmemSize = Is_Q_use_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
#if 1
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;
#else
    using GmemLayoutAtom = Layout,
                                  Stride< _4, _1>>;
#endif

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    // Number of rows each thread has to fetch.
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
    // mainloop iteration. Rudimentary testing shows no slowdown.
    using GmemTiledCopyQKVPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout, _8>, Stride<_8, _1>>{}));
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 8 vals per load
    using GmemTiledCopyRotcossinPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinContPaged = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
};

// Disabled (#if 0) version of Flash_fwd_kernel_16x64_prefetch_traits, kept for reference; the
// active definition is in the #else branch.
#if 0
template >
struct Flash_fwd_kernel_16x64_prefetch_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = true;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static_assert(kBlockN % 64 == 0);
    static_assert(kHeadDim % 32 == 0);
    static_assert(kHeadDimV % 32 == 0);
    static constexpr int kStages = kStages_;
    // Shared-memory tile extents scale with kStages + 1.
    static constexpr int kBlockNSmem = (kStages + 1) * 16;
    static constexpr int kBlockKSmem = (kStages + 1) * 32;

    using MMA_Atom_Arch_16x64 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using MMA_Atom_Arch_16x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x64 = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x32 = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;

    using SmemLayoutAtomK = Layout, Int<32>>,
                                   Stride, _1>>;
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{}));
    using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{}));

    using SmemLayoutAtomV = Layout, Int<32>>,
                                   Stride, _1>>;
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{}));
    using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<(kStages + 1)*kHeadDimV>>{}));
    using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{})));

    // Only the larger of the K and V split buffers is reserved (presumably they are not live at
    // the same time — confirm).
    static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element);
    static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element);
    static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize);

    using GmemLayoutAtom = Layout,
                                  Stride< _8, _1>>;
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
};
#else
template >
struct Flash_fwd_kernel_16x64_prefetch_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = true;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 64;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static_assert(kBlockN % 64 == 0);
    static_assert(kHeadDim % 32 == 0);
    static_assert(kHeadDimV % 32 == 0);
    static constexpr int kStages = kStages_;
    static constexpr int ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_;

    // struct __align__(128) Shared {
    //     // 4 warps in a warpgroup vote to an atomic variable in shared memory
    //     uint32_t skip_softmax_votes;
    // };

    using MMA_Atom_Arch_16x64 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using MMA_Atom_Arch_16x32 = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x64 = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;
    using TiledMma16x32 = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;

    static constexpr uint32_t LayoutBlock = 64;
    static constexpr uint32_t LayoutDim = 128;

    using SmemLayoutAtomK = Layout, Int<32>>,
                                   Stride, _1>>;
    using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element); static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize); // static constexpr int kSmemSize = ENABLE_SKIP_SOFTMAX ? std::max(kSmemKSize, kSmemVSize) + sizeof(Shared) : std::max(kSmemKSize, kSmemVSize); using GmemLayoutAtom = Layout, Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; #endif template > struct Flash_fwd_kernel_16x64_prefetch_traits_fp8 : public Base { using Element = typename Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; // static constexpr int kBlockNSmem = (kStages + 1) * 32; // static constexpr int kBlockKSmem = (kStages + 1) * 32; using MMA_Atom_Arch_16x64 = MMA_Atom;//qk using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom;//pv using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomK = Layout, Int>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutV{}) * sizeof(Element); static constexpr int kSmemSize = 16384; //static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize); using GmemLayoutAtom = Layout, Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_fwd_kernel_16x64_prefetch_traits_dim96 : public Base { using Element = typename 
Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomK = Layout, Int<32>>, Stride, _1>>; using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*kHeadDimV>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, 
GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element); static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize); using GmemLayoutAtom = Layout, Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_fwd_kernel_16x64_prefetch_traits_dim64 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages_GEMM0 = kHeadDim / 32; static constexpr int kStages_GEMM1 = kBlockN / 16; static constexpr int kBlockNSmem = (kStages_GEMM1) * 16; static constexpr int kBlockKSmem = (kStages_GEMM0) * 32; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename 
Base::ValLayoutMNK>; using SmemLayoutAtomK = Layout, Int<32>>, Stride, _1>>; using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<(kStages_GEMM1)*kHeadDimV>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element); static constexpr int kSmemSize = kSmemKSize + kSmemVSize; using GmemLayoutAtom = Layout, Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using SmemLayoutAtomO = decltype( composition(Swizzle<3, 3, 3>{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; }; template > struct Flash_fwd_kernel_16x64_prefetch_mla_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); // 
static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 0 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_fwd_kernel_16x64_prefetch_mla_traits_fp8 : public Base { using Element = typename Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64;//256 static constexpr int kBlockM = kBlockM_;//128 static constexpr int kBlockN = kBlockN_;//64 static constexpr int kHeadDim = kHeadDim_;//192 static constexpr int kHeadDimV = kHeadDimV_;//128 static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_;//3 static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;//64 static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);//64 static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = MMA_Atom; using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; //using MMA_Atom_Arch_16x32 = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomK = Layout, Int<192>>, Stride, _1>>;//128,192 using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<192>>{}));//64,192 using SmemLayoutK = Layout, Int<64>>, Stride, _1>>;//192,64 // using SmemLayoutAtomV = decltype(composition( // Swizzle{}, // Layout, Int<64>>, Stride, _1>>{})); // using SmemLayoutV = decltype(tile_to_shape( // SmemLayoutAtomV{}, // Shape, Int>{}));//64,128 using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); static constexpr int kSmemKVSize = 16384; //static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); //static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);//128*128*1 // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);//16 static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;//64/16=4 static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 0 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_mla_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; 
using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( 
make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtomK{},Shape, Int<128>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype(composition(Swizzle{}, Layout, Int>,Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * 
sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = 
decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits_dim256 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomK = Layout, Int<256>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int<256>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; // using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; // using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); // using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); // using SmemLayoutAtomK = Layout, Int<32>>, Stride, _1>>; // using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); // using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = 
decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutAtomO = decltype(composition(Swizzle{}, Layout, Int>,Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits : public Base { using Element = typename 
Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = 
decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8 : public Base { using Element = typename 
Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = MMA_Atom; // using MMA_Atom_Arch_16x64x16 = MMA_Atom; using MMA_Atom_Arch_16x64x32_Blayout = MMA_Atom; using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; // using MMA_Atom_Arch_16x64_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64_BLayout_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64x32_NN = MMA_Atom; using MMA_Atom_Arch_16x64_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x32_NN = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; // using TiledMma16x64x16 = TiledMMA< // MMA_Atom_Arch_16x64x16, // Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group // typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; 
using TiledMma16x64x32BLayout = TiledMMA< MMA_Atom_Arch_16x64x32_Blayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64_LIT = TiledMMA< MMA_Atom_Arch_16x64_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64_Blayout_LIT = TiledMMA< MMA_Atom_Arch_16x64_BLayout_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x32_NN = TiledMMA< MMA_Atom_Arch_16x64x32_NN, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); // using SmemLayoutK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementO); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize + 2*size(SmemLayoutV{}) * sizeof(Element); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / 
sizeof(ElementO); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim192 : public Base { using Element = typename Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = MMA_Atom; // using MMA_Atom_Arch_16x64x16 = MMA_Atom; using MMA_Atom_Arch_16x64x32_Blayout = MMA_Atom; using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; // using MMA_Atom_Arch_16x64_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64_BLayout_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64x32_NN = MMA_Atom; using MMA_Atom_Arch_16x64_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x32_NN = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; // using TiledMma16x64x16 = TiledMMA< // MMA_Atom_Arch_16x64x16, // Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group // typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64x32BLayout = TiledMMA< MMA_Atom_Arch_16x64x32_Blayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64_LIT = TiledMMA< MMA_Atom_Arch_16x64_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64_Blayout_LIT = TiledMMA< MMA_Atom_Arch_16x64_BLayout_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group 
typename Base::ValLayoutMNK>; using TiledMma16x64x32_NN = TiledMMA< MMA_Atom_Arch_16x64x32_NN, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<192>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<192>>{})); // using SmemLayoutK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementO); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementO); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim256 : public Base { using Element = typename Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = MMA_Atom; // using MMA_Atom_Arch_16x64x16 = MMA_Atom; using MMA_Atom_Arch_16x64x32_Blayout = MMA_Atom; using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; // using MMA_Atom_Arch_16x64_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64_BLayout_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64x32_NN = MMA_Atom; using MMA_Atom_Arch_16x64_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x32_NN = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; // using TiledMma16x64x16 = TiledMMA< // MMA_Atom_Arch_16x64x16, // Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group // typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64x32BLayout = TiledMMA< MMA_Atom_Arch_16x64x32_Blayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64_LIT = TiledMMA< MMA_Atom_Arch_16x64_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64_Blayout_LIT = TiledMMA< MMA_Atom_Arch_16x64_BLayout_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group 
typename Base::ValLayoutMNK>; using TiledMma16x64x32_NN = TiledMMA< MMA_Atom_Arch_16x64x32_NN, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<256>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<256>>{})); // using SmemLayoutK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementO); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementO); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8 : public Base { using Element = typename Base::Element; using ElementKV = elem_type_kv; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); // using SmemLayoutK = Layout, Int<128>>, Stride, _1>>; using 
SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize + 2*size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim64 : public Base { using Element = typename Base::Element; using ElementKV = elem_type_kv; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<64>>{})); // using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutK 
= Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim256 : public Base { using Element = typename Base::Element; using ElementKV = elem_type_kv; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<256>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<256>>{})); // using SmemLayoutK = Layout, Int<256>>, Stride, _1>>; using 
SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim192 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<192>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<192>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, 
Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, 
GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim256 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<256>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<256>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * 
sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, 
Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim64 : public Base { using Element = typename Base::Element; using ElementO = elem_type_o; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = MMA_Atom; // using MMA_Atom_Arch_16x64x16 = MMA_Atom; using MMA_Atom_Arch_16x64x32_Blayout = MMA_Atom; using MMA_Atom_Arch_16x64_BLayout = MMA_Atom; // using MMA_Atom_Arch_16x64_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64_BLayout_LIT = MMA_Atom; // using MMA_Atom_Arch_16x64x32_NN = MMA_Atom; using MMA_Atom_Arch_16x64_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout_LIT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x32_NN = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; // using TiledMma16x64x16 = TiledMMA< // MMA_Atom_Arch_16x64x16, // Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group // typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64x32BLayout = TiledMMA< MMA_Atom_Arch_16x64x32_Blayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x64_LIT = TiledMMA< MMA_Atom_Arch_16x64_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64_Blayout_LIT = TiledMMA< MMA_Atom_Arch_16x64_BLayout_LIT, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x32_NN = TiledMMA< MMA_Atom_Arch_16x64x32_NN, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<64>>{})); // using SmemLayoutK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = 
decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementO); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize + 2*size(SmemLayoutV{}) * sizeof(Element); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementO); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 
16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_ws_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<128>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<128>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, 
Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = 65536; // static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 512 == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( 
make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<64>>{})); using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); // using SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * 
sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = kSmemKSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, 
Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_splitkv_vllm_kvcache_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64x16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64x16 = TiledMMA< MMA_Atom_Arch_16x64x16, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomK = decltype( composition(Swizzle<3, 3, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = decltype( composition(Swizzle<3, 3, 4>{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomV = Layout, Int<64>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<64>>{})); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; // using SmemLayoutAtomV = Layout, Int<16>>, Stride, _1>>; // using SmemLayoutV = decltype(tile_to_shape( // SmemLayoutAtomV{}, // Shape, Int<64>>{})); // using 
SmemLayoutVtransposed = decltype( // composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element); static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); // static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize); static constexpr int kSmemSize = (kBlockN * kHeadDim + kBlockN * kHeadDimV) * 2;//(kv+v)*B static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); using GmemTiledCopyQKVPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout, _8>, Stride<_8, _1>>{})); using 
GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load using GmemTiledCopyRotcossinPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinContPaged = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; template > struct Flash_fwd_kernel_16x64_prefetch_traits_dim256 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomK = Layout, Int<32>>, Stride, _1>>; using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element); static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize); using GmemLayoutAtom = Layout, 
Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_fwd_kernel_16x64_prefetch_traits_dim512 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = true; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kBlockN % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); static constexpr int kStages = kStages_; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x64 = TiledMMA< MMA_Atom_Arch_16x64, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; using TiledMma16x32 = TiledMMA< MMA_Atom_Arch_16x32, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomK = Layout, Int<32>>, Stride, _1>>; using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomK{}, Shape, Int>{})); using SmemLayoutAtomV = Layout, Int<32>>, Stride, _1>>; using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, 
Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKsplit{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVsplit{}) * sizeof(Element); static constexpr int kSmemSize = std::max(kSmemKSize, kSmemVSize); using GmemLayoutAtom = Layout, Stride< _8, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Is_V_in_regs = Is_V_in_regs_; static constexpr bool No_double_buffer = No_double_buffer_; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutNdKV == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomQdO = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, make_shape(Int{}, Int{}))); // using SmemLayoutAtomKV = decltype( // composition(Swizzle{}, // Layout, Int>, // Stride, _1>>{})); using SmemLayoutAtomKV = decltype( composition(Swizzle<2, 4, 2>{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 static constexpr int kPBlockN = kBlockN; // Temporarily disabling this for hdim 256 on sm86 and sm89 // static_assert(kBlockN >= 64); // static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. // static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3); static constexpr int kSwizzlePdS = 3; // using SmemLayoutAtomPdS = decltype( // composition(Swizzle{}, // Layout, Int>, // Stride, _1>>{})); using SmemLayoutAtomPdS = decltype( composition(Swizzle<2, 4, 2>{}, Layout, Int<4>>, Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); using SmemLayoutPdStransposed = decltype( composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); // using SmemLayoutAtomdKV = decltype( // composition(Swizzle{}, // Layout>, // Stride, _1>>{})); using SmemLayoutAtomdKV = decltype( composition(Swizzle<2, 4, 2>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomdQ = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemdSSize = (Is_V_in_regs ? size(SmemLayoutKV{}): size(SmemLayoutPdS{})) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? 
kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); // wangaq debug for dq static constexpr int kSmemSize1rowblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize : kSmemKVSize / 2); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. 
using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( make_tiled_copy(Copy_Atom{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store }; template > struct Flash_bwd_kernel_dq_traits : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDim; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); using TiledMmaSdP = typename Base::TiledMmaSdP; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch_FOR_GEMM1, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomQdO = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, make_shape(Int{}, Int{}))); using SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); // using SmemLayoutAtomKV = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int>, // Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 
3 is faster than 1, and 1 is faster than 2 static constexpr int kPBlockN = kBlockN; // Temporarily disabling this for hdim 256 on sm86 and sm89 // static_assert(kBlockN >= 64); // static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. // static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); // using SmemLayoutAtomPdS = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); using SmemLayoutPdStransposed = decltype( composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomdQ = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * 2 * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemSize1rowblock = Share_Q_K_smem ? 
std::max(kSmemQdOSize, kSmemKVSize) : kSmemQdOSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store #if 0 using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store #else using GmemLayoutAtomdQ = Layout, Stride< _4, _1>>; using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store #endif }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Flash_bwd_kernel_trans_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Is_V_in_regs = Is_V_in_regs_; static constexpr bool 
No_double_buffer = No_double_buffer_; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutNdKV == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< typename Base::MMA_Atom_Arch_FOR_GEMM1, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch_FOR_GEMM1, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomQdO = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, make_shape(Int{}, Int{}))); using SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); // using SmemLayoutAtomKV = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values 
of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; // Temporarily disabling this for hdim 256 on sm86 and sm89 // static_assert(kBlockN >= 64); static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); // using SmemLayoutAtomPdS = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); using SmemLayoutPdStransposed = decltype( composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<2, 4, 2>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // 
Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemdSSize = (Is_V_in_regs ? size(SmemLayoutKV{}): size(SmemLayoutPdS{})) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); static constexpr int kSmemSizeTrans1colblock = kSmemQdOSize + kSmemKVSize + kSmemdQSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. 
using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( make_tiled_copy(Copy_Atom{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store using GmemLayoutAtomdKVaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _64, _1>> >; using GmemTiledCopydKVaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdKVaccum{}, Layout>{})); // Val layout, 4 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Is_V_in_regs = Is_V_in_regs_; static constexpr bool No_double_buffer = No_double_buffer_; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutNdKV == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); #if 0 using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int, _1>>, typename Base::ValLayoutMNK>; #else using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom>, MMA_Atom> >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int, _1>>, typename Base::ValLayoutMNK>; #endif using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch_FOR_GEMM1, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; // 打开swizzle #if 1 using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); #else using SmemLayoutAtomQdO = decltype( Layout>, Stride, _1>>{}); #endif // using SmemLayoutAtomQdO = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, make_shape(Int{}, Int{}))); #if 0 using 
SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #else using SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #endif // using SmemLayoutAtomKV = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; // Temporarily disabling this for hdim 256 on sm86 and sm89 // static_assert(kBlockN >= 64); static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3); static constexpr int kSwizzlePdS = 3; #if 0 using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #else using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #endif // using SmemLayoutAtomPdS = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); using SmemLayoutPdStransposed = decltype( composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<2, 4, 2>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemdSSize = (Is_V_in_regs ? 
size(SmemLayoutKV{}): size(SmemLayoutPdS{})) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); #if 0 static constexpr int kSmemSizeTrans1colblock = kSmemQdOSize + kSmemKVSize; #else static constexpr int kSmemSizeTrans1colblock = std::max(kSmemQdOSize, kSmemKVSize); // kSmemQdOSize + kSmemKVSize; #endif static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 0 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. 
using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomKV = Layout, Stride<_4, _1>>; using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomKV{}, Layout>>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( make_tiled_copy(Copy_Atom{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store using GmemLayoutAtomdKVaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row Stride< _8, _1>>, Layout, // Thread layout, 16 threads per row Stride< _64, _1>> >; using GmemTiledCopydKVaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdKVaccum{}, Layout>{})); // Val layout, 4 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_prefetch_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static constexpr int kStages = kStages_; static constexpr int kBlockNSmem = kStages * 16; static constexpr int kBlockKSmem = kStages * 32; static_assert(kBlockM % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using SmemLayoutAtomQGemm0 = Layout, Int>, Stride, _1>>; using SmemLayoutAtomdOGemm0 = Layout, Int>, Stride, _1>>; using SmemLayoutQGemm0 = decltype(tile_to_shape(SmemLayoutAtomQGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutdOGemm0 = decltype(tile_to_shape(SmemLayoutAtomdOGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutQ = Layout, Int<64>>, Stride, _1>>; using SmemLayoutdO = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomQdOGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutQGemm1transposed = decltype(composition(SmemLayoutQGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQGemm1transposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQGemm1transposed{})); using SmemLayoutQsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<(kStages + 1)*kHeadDim>>{})); 
using SmemLayoutQtransSplit = decltype(composition(SmemLayoutQsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutdOGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutdOGemm1transposed = decltype(composition(SmemLayoutdOGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutdOGemm1transposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdOGemm1transposed{})); using SmemLayoutdOsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<(kStages + 1)*kHeadDimV>>{})); using SmemLayoutdOtransSplit = decltype(composition(SmemLayoutdOsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<4, 2, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemLayoutdVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; static constexpr int kSmemQSize = size(SmemLayoutQGemm0{}) * sizeof(Element); static constexpr int kSmemdOSize = size(SmemLayoutdOGemm0{}) * sizeof(Element); static constexpr int kSmemOffset = kHeadDim == 192 ? 4096 : 0; static constexpr int kSmemPrefetchSize = kHeadDim == 64 ? 
kSmemQSize + kSmemdOSize : std::max(kSmemQSize, kSmemdOSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 8; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static constexpr int kStages = kStages_; static_assert(kBlockM % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using SmemLayoutAtomQdOGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdO = decltype(tile_to_shape(SmemLayoutAtomQdOGemm0{}, Shape, Int>{})); using SmemLayoutAtomQdOGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutQGemm1transposed = decltype(composition(SmemLayoutQGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<4*kHeadDim>>{})); using SmemLayoutQtransSplit = decltype(composition(SmemLayoutQsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutdOGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutdOGemm1transposed = decltype(composition(SmemLayoutdOGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutdOsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<4*kHeadDimV>>{})); using SmemLayoutdOtransSplit = decltype(composition(SmemLayoutdOsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<4, 2, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKStore = 
decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemLayoutdVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; static constexpr int kSmemQSize = size(SmemLayoutQdO{}) * sizeof(Element); static constexpr int kSmemdOSize = size(SmemLayoutQdO{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = kHeadDim == 64 ? kSmemQSize + kSmemdOSize : std::max(kSmemQSize, kSmemdOSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 8; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static constexpr int kStages = kStages_; static_assert(kBlockM % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1, _1>>, typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomQdOGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdOGemm0 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm0{}, Shape, Int>{})); using SmemLayoutAtomQdOGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdOGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutQdOGemm1transposed = decltype(composition(SmemLayoutQdOGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutQdOtransSplit = decltype(composition(SmemLayoutQdOsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<4, 2, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemLayoutdVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; static constexpr int kSmemQdOSize = size(SmemLayoutQdOGemm0{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = 
kSmemQdOSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 8; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512 : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static constexpr int kStages = kStages_; static_assert(kBlockM % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1, _1>>, typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomQdOGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdOGemm0 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm0{}, Shape, Int>{})); using 
SmemLayoutAtomQdOGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdOGemm1 = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutQdOGemm1transposed = decltype(composition(SmemLayoutQdOGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOsplit = decltype(tile_to_shape(SmemLayoutAtomQdOGemm1{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutQdOtransSplit = decltype(composition(SmemLayoutQdOsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<4, 2, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemLayoutdVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; static constexpr int kSmemQdOSize = size(SmemLayoutQdOGemm0{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = kSmemQdOSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 8; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_trans_16x64_prefetch_mla_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static constexpr int kStages = kStages_; static_assert(kBlockM % 64 == 0); static_assert(kHeadDim % 32 == 0); static_assert(kHeadDimV % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMmadKV = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1, _1>>, typename Base::ValLayoutMNK>; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using SmemLayoutAtomQdOGemm0 = Layout, Int>, Stride, _1>>; using SmemLayoutQdOGemm0 = decltype(tile_to_shape( SmemLayoutAtomQdOGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutQdO = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomQdOGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutQdOGemm1 = decltype(tile_to_shape( SmemLayoutAtomQdOGemm1{}, Shape, Int>{})); using SmemLayoutQdOGemm1transposed = decltype( composition(SmemLayoutQdOGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQGemm1Tail = decltype(tile_to_shape( SmemLayoutAtomQdOGemm1{}, Shape, Int<64>>{})); using SmemLayoutQGemm1TailTransposed = decltype( composition(SmemLayoutQGemm1Tail{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutAtomdKVStore = decltype( composition(Swizzle<4, 2, 4>{}, Layout>, Stride, _1>>{})); using SmemLayoutdKStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, make_shape(Int{}, Int{}))); using SmemLayoutdVStore = decltype(tile_to_shape( SmemLayoutAtomdKVStore{}, 
make_shape(Int{}, Int{}))); using SmemCopyAtomdKV = Copy_Atom; static constexpr int kSmemQdOSize = size(SmemLayoutQdOGemm0{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = kSmemQdOSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = 8; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); #if 1 using GmemLayoutAtom = Layout, Int>, Stride, _1>>; #else using GmemLayoutAtom = Layout, Stride< _4, _1>>; #endif using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 vals per store }; ////////////////////////////////////////////////////////////////////////////////////////////// template > struct Flash_bwd_kernel_dq_16x64_traits : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); #if 0 using TiledMmaSdP = typename Base::TiledMmaSdP; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch_FOR_GEMM1, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; #else using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom>, MMA_Atom> >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; #endif using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutV = Layout, Int<64>>, Stride, _1>>; // using SmemLayoutAtomQdO = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, make_shape(Int{}, Int{}))); #if 0 using SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #else using SmemLayoutAtomKV = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); #endif // using SmemLayoutAtomKV = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int>, // Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of 
kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 static constexpr int kPBlockN = kBlockN; // Temporarily disabling this for hdim 256 on sm86 and sm89 // static_assert(kBlockN >= 64); // static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. // static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); // using SmemLayoutAtomPdS = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout, Int<4>>, // Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); using SmemLayoutPdStransposed = decltype( composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); // using SmemLayoutAtomdQ = decltype( // composition(Swizzle<2, 4, 2>{}, // Layout>, // Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * 2 * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemSize1rowblock = Share_Q_K_smem ? 
std::max(size(SmemLayoutdQ{}) * 2, kSmemKVSize) : kSmemQdOSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemLayoutAtomdQ = Layout, Int<4>>, Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Flash_bwd_kernel_dq_16x64_prefetch_traits : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kStages = kStages_; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); #if 0 #else using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64_BLayout = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma16x64BLayout = TiledMMA< MMA_Atom_Arch_16x64_BLayout, Layout,_1,_1>>, typename Base::ValLayoutMNK>; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; #endif using SmemLayoutAtomdQ = Layout, Int<32>>, Stride, _1>>; using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemLayoutV = Layout, Int<64>>, Stride, _1>>; using SmemLayoutK = Layout, Int<64>>, Stride, _1>>; using SmemLayoutAtomKGemm0 = Layout, Int>, Stride, _1>>; using SmemLayoutAtomVGemm0 = Layout, Int>, Stride, _1>>; using SmemLayoutKGemm0 = decltype(tile_to_shape( SmemLayoutAtomKGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutVGemm0 = decltype(tile_to_shape( SmemLayoutAtomVGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutAtomKGemm1 = decltype( composition(Swizzle<0, 3, 3>{}, Layout, Int<32>>, Stride, 
_1>>{})); using SmemLayoutKGemm1 = decltype(tile_to_shape( SmemLayoutAtomKGemm1{}, make_shape(Int{}, Int{}))); using SmemLayoutKGemm1transposed = decltype( composition(SmemLayoutKGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKGemm1transposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKGemm1transposed{})); static constexpr int kSmemKSize = size(SmemLayoutKGemm0{}) * sizeof(Element); static constexpr int kSmemVSize = size(SmemLayoutVGemm0{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = kHeadDim == 64 ? kSmemKSize + kSmemVSize : std::max(kSmemKSize, kSmemVSize); // static constexpr int kSmemPrefetchSize = std::max(std::max(kSmemKSize, kSmemVSize), kSmemdQSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemThreadsPerRow = 8; using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); using GmemLayoutAtomdQ = Layout, Int<4>>, Stride, _1>>; using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96 : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kStages = kStages_; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group typename Base::ValLayoutMNK>; using SmemLayoutAtomKVGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutKVGemm0 = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutKVGemm0Split = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, Shape, Int<128>>{})); using SmemLayoutAtomKGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutKGemm1 = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, make_shape(Int{}, Int{}))); using SmemLayoutKGemm1transposed = decltype(composition(SmemLayoutKGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, Shape, Int<4*kHeadDim>>{})); using SmemLayoutKtransSplit = decltype(composition(SmemLayoutKsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKVGemm0Split{}) * sizeof(Element); static constexpr int kSmemKtSize = size(SmemLayoutKtransSplit{}) * sizeof(Element); static constexpr int kSmemOffset = 3072; static constexpr int kSmemOffsetSize = kSmemOffset * sizeof(Element); static constexpr int kSmemPrefetchSize = std::max(kSmemKSize, 
kSmemKtSize + kSmemOffsetSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemThreadsPerRow = 8; using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); using GmemLayoutAtomdQ = Layout, Int<4>>, Stride, _1>>; using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256 : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kStages = kStages_; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int<1>, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int<1>, _1>>, typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomKVGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutKVGemm0 = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutKVGemm0Split = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, Shape, Int<128>>{})); using SmemLayoutAtomKGemm1 = Layout, Int<32>>, Stride, _1>>; 
using SmemLayoutKGemm1 = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, make_shape(Int{}, Int{}))); using SmemLayoutKGemm1transposed = decltype(composition(SmemLayoutKGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutKtransSplit = decltype(composition(SmemLayoutKsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKVGemm0Split{}) * sizeof(Element); static constexpr int kSmemKtSize = size(SmemLayoutKtransSplit{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = std::max(kSmemKSize, kSmemKtSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemThreadsPerRow = 8; using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); using GmemLayoutAtomdQ = Layout, Int<4>>, Stride, _1>>; using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store }; template > struct Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512 : public Flash_bwd_kernel_traits { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; // The number of threads. 
static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kStages = kStages_; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static constexpr int kHeadDimV = kHeadDimV_; static_assert(kHeadDim % 32 == 0); using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch_16x64, Layout, Int<1>, _1>>, typename Base::ValLayoutMNK>; using TiledMmadQ = TiledMMA< MMA_Atom_Arch_16x32, Layout, Int<1>, _1>>, typename Base::ValLayoutMNK>; static constexpr uint32_t LayoutBlock = 64; static constexpr uint32_t LayoutDim = 128; using SmemLayoutAtomKVGemm0 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutKVGemm0 = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, make_shape(Int{}, Int{}))); using SmemLayoutKVGemm0Split = decltype(tile_to_shape(SmemLayoutAtomKVGemm0{}, Shape, Int<128>>{})); using SmemLayoutAtomKGemm1 = Layout, Int<32>>, Stride, _1>>; using SmemLayoutKGemm1 = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, make_shape(Int{}, Int{}))); using SmemLayoutKGemm1transposed = decltype(composition(SmemLayoutKGemm1{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKsplit = decltype(tile_to_shape(SmemLayoutAtomKGemm1{}, Shape, Int<4*LayoutDim>>{})); using SmemLayoutKtransSplit = decltype(composition(SmemLayoutKsplit{}, make_layout(Shape, Int<16>>{}, GenRowMajor{}))); static constexpr int kSmemKSize = size(SmemLayoutKVGemm0Split{}) * sizeof(Element); static constexpr int kSmemKtSize = size(SmemLayoutKtransSplit{}) * sizeof(Element); static constexpr int kSmemPrefetchSize = std::max(kSmemKSize, kSmemKtSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemThreadsPerRow = 8; using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using 
GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); using GmemLayoutAtomdQ = Layout, Int<4>>, Stride, _1>>; using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQ{}, Layout>{})); // Val layout, 8 vals per store };