#pragma once #include #include #include #include using namespace cute; #include "utils.h" #include "softmax.h" #include "static_switch.h" #include "flash_mla.h" #if 1 template struct Flash_fwd_kernel_traits_mla { using Element = elem_type; using Element_O = Element; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); static constexpr int kStages = kHeadDimV_ / 32 - 2; // static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = 3; static constexpr int kBlockKSmem = (kStages + 1) * 32; static constexpr int kBlockNSmem = (kStages + 1) * 16; using ValLayoutMNK = Layout>; #if defined(__gfx936__) || defined(__gfx938__) using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; #elif defined(__gfx928__) using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; #endif using MMA_Atom_Arch_1 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // using SmemLayoutRow = Layout>, Stride<_4, _1>>; using SmemLayoutRow = Layout, Stride<_1>>; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutAtomK_temp = decltype(composition( Swizzle{}, Layout, Int<32>>, Stride, _1>>{})); using 
SmemLayoutAtomK = SmemLayoutAtomK_temp; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<16 * 32>>{})); using SmemLayoutK_temp = decltype(tile_to_shape( SmemLayoutAtomK_temp{}, Shape, Int<15 * 32>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); // using SmemLayoutAtomP = Layout, Int>, Stride, _1>>; // using SmemLayoutAtomP = decltype(composition( // Swizzle<4, 2, 4>{}, // Layout, Int<64>>, Stride, _1>>{})); // using SmemLayoutP = decltype(tile_to_shape( // SmemLayoutAtomP{}, // Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); // using SmemLayoutAtomQ = // Layout, Int<32>>, Stride, _1>>; using SmemLayoutAtomQ = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomO = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using GmemLayoutAtomK = Layout, Stride< _4, _1>>; using GmemTiledCopyK = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomK{}, Layout>{})); #if 1 using GmemTiledCopyV = GmemTiledCopyK; #else using GmemLayoutAtomV = Layout, Stride< _16, _1>>; using GmemTiledCopyV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomV{}, Layout>{})); #endif using GmemLayoutAtomO = Layout, Stride< _16, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomO{}, Layout>{})); using GmemLayoutAtomQ = Layout, 
Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomOaccum = Layout, Stride< _16, _1>>; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); }; template struct Flash_fwd_kernel_traits_mla_tp1 { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); using ValLayoutMNK = Layout>; using MMA_Atom_Arch_16_16_32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_16_16_32 = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_16_16_32_for_copy = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using MMA_Atom_Arch_16x32_NT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32_NT, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<18 * 32>>{})); using SmemLayoutRow = Layout, Stride<_1>>; using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<512>>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int<32>>{}, GenRowMajor{}))); using 
SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); }; template struct Flash_fwd_kernel_traits_mla_kvfp8 { using Element = elem_type; using Element_O = true_elem_type; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); static constexpr int kStages = kHeadDimV_ / 32 - 2; // static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = 3; static constexpr int kBlockKSmem = (kStages + 1) * 32; static constexpr int kBlockNSmem = (kStages + 1) * 16; using MMA_Atom_Arch_16_16_32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_int8 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16x64 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using MMA_Atom_Arch_16_32_16 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; // using MMA_Atom_Arch_16x32_int8 = std::conditional_t< // std::is_same_v, // MMA_Atom, // MMA_Atom // >; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_16_16_32 = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_int8 = TiledMMA< MMA_Atom_Arch_int8, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x64, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_O_16_32_16 = TiledMMA< MMA_Atom_Arch_16_32_16, 
Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // using TiledMma_O_int8 = TiledMMA< // MMA_Atom_Arch_16x32_int8, // Layout, _1>>, // 1x4x1 or 1x8x1 thread group // ValLayoutMNK>; using SmemLayoutRow = Layout, Stride<_1>>; // using SmemLayoutRow = Layout>, Stride<_4, _1>>; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<8 * 32>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<8 * 32>>{})); // using SmemLayoutAtomP = decltype(composition( // Swizzle<4, 2, 4>{}, // Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomV_fp8 = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutV_fp8 = decltype(tile_to_shape( SmemLayoutAtomV_fp8{}, Shape, Int<512>>{})); using SmemLayoutVtransposed_fp8 = decltype( composition(SmemLayoutV_fp8{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle_fp8 = decltype(get_nonswizzle_portion(SmemLayoutVtransposed_fp8{})); using SmemLayoutAtomQ = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); using SmemLayoutAtomO = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using GmemLayoutAtomK = Layout, Stride< _4, _1>>; using GmemTiledCopyK = decltype( 
make_tiled_copy(Copy_Atom{}, GmemLayoutAtomK{}, Layout>{})); #if 1 using GmemTiledCopyV = GmemTiledCopyK; #else using GmemLayoutAtomV = Layout, Stride< _16, _1>>; using GmemTiledCopyV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomV{}, Layout>{})); #endif using GmemLayoutAtomO = Layout, Stride< _16, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomO{}, Layout>{})); using GmemLayoutAtomQ = Layout, Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomOaccum = Layout, Stride< _16, _1>>; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); }; template struct Flash_fwd_kernel_traits_mla_kvfp8_TP1 { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? 
kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); using ValLayoutMNK = Layout>; using MMA_Atom_Arch_16_16_32 = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_16_16_32 = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using TiledMma_16_16_32_for_copy = TiledMMA< MMA_Atom_Arch_16_16_32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using MMA_Atom_Arch_16x32_NT = std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >; using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32_NT, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; using SmemLayoutAtomK = decltype(composition( Swizzle<3, 3, 3>{}, Layout, Int<32>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<18 * 32>>{})); using SmemLayoutRow = Layout, Stride<_1>>; using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int<512>>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int<32>>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); }; namespace flash { using namespace cute; template struct SharedStorageMLATP1 { union { struct { cute::array_aligned> smem_k; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; // cute::array_aligned> smem_v; // Double buffer // cute::array_aligned> smem_p; // cute::array_aligned> smem_row_sum; // cute::array_aligned> smem_row_max; }; }; }; template struct SharedStorageMLA { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; 
cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_o; // cute::array_aligned> smem_p; // cute::array_aligned> smem_row_sum; // cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_q; }; // struct { // cute::array_aligned> smem_max; // cute::array_aligned> smem_sum; // cute::array_aligned> smem_o; // }; }; }; template struct SharedStorageMLAFp8 { union { struct { cute::array_aligned> smem_v; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_o; // cute::array_aligned> smem_p; // cute::array_aligned> smem_row_sum; // cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_q; }; }; }; template struct SharedStorageMLAFp8_TP1 { union { struct { cute::array_aligned> smem_k; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; // cute::array_aligned> smem_v; // Double buffer // cute::array_aligned> smem_p; // cute::array_aligned> smem_row_sum; // cute::array_aligned> smem_row_max; }; }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// #if 1 template __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); // Epilogue const int split_offset = __ldg(params.num_splits_ptr + bidb); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), 
typename Kernel_traits::SmemLayoutRow{}); Tensor lse = softmax.template normalize_softmax_lse(tOrO, sRow_sum_reduce_buffer, params.scale_softmax); // if (thread0()) // { // printf(" %d %.2f\n", (int)size(lse), lse(0)); // } using ElementO = std::conditional_t; Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning using SmemTiledCopyO = std::conditional_t< !Split, typename Kernel_traits::SmemCopyAtomO, typename Kernel_traits::SmemCopyAtomOaccum >; auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor rO = flash::convert_type(tOrO); Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) // __syncthreads(); cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? 
row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); using GmemTiledCopyO = std::conditional_t; GmemTiledCopyO gmem_tiled_copy_Oaccum; // if (thread0()) // { // print(gmem_tiled_copy_Oaccum); print("\n"); // } auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); __syncthreads(); // if (tidx >= kNThreadsS) { return; } Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) // Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); // CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } // Construct identity layout for sO Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM ); } #endif template __forceinline__ __device__ void store_fp8(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = 
typename Kernel_traits::Element_O; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); // Epilogue const int split_offset = __ldg(params.num_splits_ptr + bidb); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); Tensor lse = softmax.template normalize_softmax_lse(tOrO, sRow_sum_reduce_buffer, params.scale_softmax); // if (thread0()) // { // printf("Split = %d %d %.2f\n",Split, (int)size(lse), lse(0)); // } using ElementO = std::conditional_t; // Partition sO to match the accumulator partitioning using SmemTiledCopyO = std::conditional_t< !Split, typename Kernel_traits::SmemCopyAtomO, typename Kernel_traits::SmemCopyAtomOaccum >; Tensor rO = flash::convert_type(tOrO); // { // int tidx = threadIdx.x; // if (block0()&& tidx < 64) { // printf("tidx = %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", // tidx, // float(rO(0)), float(rO(1)), float(rO(2)), float(rO(3)), // float(rO(4)), float(rO(5)), float(rO(6)), float(rO(7)), // float(rO(8)), float(rO(9)), float(rO(10)), float(rO(11)), // float(rO(12)), float(rO(13)), float(rO(14)), float(rO(15)) // ); // print("acc_o\n"); print(acc_o); print("\n"); // } // } // __syncthreads(); // cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * 
kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) // Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); // CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < size<1>(rO); m++) { const int row = get<0>(taccOcO(0, m, 0)); if (row < params.seqlen_q - m_block * kBlockM) { for (int n = 0; n < size<2>(rO); n++) { col = (tidx % 64 / 16) * 2 + warpid * 64 + n * 256; for (int ei = 0; ei < 16; ei +=2) { gOaccum(row, col) = rO(ei, m, n); gOaccum(row, col+1) = rO(ei+1, m, n); col += 8; } } } } } } template __forceinline__ __device__ void store_fp8_tp1(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; int warpid = tidx / 64; typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); // Epilogue 
const int split_offset = __ldg(params.num_splits_ptr + bidb); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); Tensor lse = softmax.template normalize_softmax_lse(tOrO, sRow_sum_reduce_buffer, params.scale_softmax); // if (thread0()) // { // printf("Split = %d %d %.2f\n",Split, (int)size(lse), lse(0)); // } using ElementO = std::conditional_t; // Partition sO to match the accumulator partitioning // using SmemTiledCopyO = std::conditional_t< // !Split, // typename Kernel_traits::SmemCopyAtomO, // typename Kernel_traits::SmemCopyAtomOaccum // >; Tensor rO = flash::convert_type(tOrO); // { // int tidx = threadIdx.x; // if (block0()&& tidx < 64) { // printf("tidx = %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", // tidx, // float(rO(0)), float(rO(1)), float(rO(2)), float(rO(3)), // float(rO(4)), float(rO(5)), float(rO(6)), float(rO(7)), // float(rO(8)), float(rO(9)), float(rO(10)), float(rO(11)), // float(rO(12)), float(rO(13)), float(rO(14)), float(rO(15)) // ); // print("acc_o\n"); print(acc_o); print("\n"); // } // } // __syncthreads(); // cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? 
kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) // Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); // CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO(0)) == 0 && warpid / 4 == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { int tidx = threadIdx.x; int col = 0; for (int m = 0; m < size<1>(rO); m++) { const int row = get<0>(taccOcO(0, m, 0)); if (row < params.seqlen_q - m_block * kBlockM) { for (int n = 0; n < size<2>(rO); n++) { col = (tidx % 64 / 16) + (warpid / 4 ) * 32 + n * 64; for (int ei = 0; ei < 8; ei ++) { gOaccum(row, col) = rO(ei, m, n); // gOaccum(row, col+1) = rO(ei+1, m, n); col += 4; } } } } } } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx936(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; // const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // Tensor gQ = 
make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), // Shape, Int>{}, // make_stride(params.q_row_stride, _1{})); const index_t row_offset_q_nope = bidb * params.q_nope_batch_stride + m_block * kBlockM * params.q_nope_row_stride + bidh * params.q_nope_head_stride; Tensor gQ_nope = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_nope_ptr) + row_offset_q_nope), Shape, Int<512>>{}, make_stride(params.q_nope_row_stride, _1{})); const index_t row_offset_q_pe = bidb * params.q_pe_batch_stride + m_block * kBlockM * params.q_pe_row_stride + bidh * params.q_pe_head_stride; Tensor gQ_pe = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_pe_ptr) + row_offset_q_pe), Shape, Int<64>>{}, make_stride(params.q_pe_row_stride, _1{})); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); // if (thread0()) // { // printf("sv sp srow = %p %p %p \n", sV.data().get(), sP.data().get(), 
sRow_max_reduce_buffer.data().get()); // } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); #if 1 // lds_direct_copy(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // // if (thread0()) // // { // // for (int i = 0; i < 16; i++) // // { // // for (int j = 0; j < 576; j++) // // { // // printf(" %.2f ", float(sQ(i, j))); // // } // // printf("\n"); // // } // // } // auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); // auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); // asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); // asm volatile("s_waitcnt vmcnt(2) \n\t 
s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); // asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17)); // __syncthreads(); //过lds读取q, 由于q是4个warp共用的 // typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; // auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); // Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); // Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ))); // if (threadIdx.x < 128) // flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, // params.seqlen_q - m_block * kBlockM); // __syncthreads(); // auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); // auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); // __syncthreads(); // #else // auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); // auto gmem_thr_copy_Q = 
gmem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); // Tensor tSrQ = thr_mma.partition_fragment_A(gQ); // Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); // flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, // params.seqlen_q - m_block * kBlockM); // __syncthreads(); typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ_pe = gmem_thr_copy_Q.partition_S(gQ_pe); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ_pe = make_identity_tensor(make_shape(size<0>(gQ_pe), size<1>(gQ_pe))); Tensor tQcQ_pe = gmem_thr_copy_Q.partition_S(cQ_pe); Tensor tQpQ_pe = make_tensor(make_shape(size<2>(tQgQ_pe))); Tensor tQgQ_nope = gmem_thr_copy_Q.partition_S(gQ_nope); Tensor cQ_nope = make_identity_tensor(make_shape(size<0>(gQ_nope), size<1>(gQ_nope))); Tensor tQpQ_nope = make_tensor(make_shape(size<2>(tQgQ_nope))); // if (threadIdx.x < 128) { for (int m = 0; m < size<1>(tQgQ_nope); m++) { if (get<0>(tQcQ_pe(0, m, 0)) < params.seqlen_q - m_block * kBlockM) { for (int k = 0; k < 512/64; k++) { cute::copy(gmem_tiled_copy_Q, tQgQ_nope(_, m, k), tQsQ(_, m, k)); } for (int k = 0; k < 1; k++) { cute::copy(gmem_tiled_copy_Q, tQgQ_pe(_, m, k), tQsQ(_, m, 8)); } } } } __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); #endif #if 0 auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSrK = 
thr_mma.partition_fragment_B(gK); Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = smem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tSgK))); #else typename Kernel_traits::GmemTiledCopyK gmem_tiled_copy_K; auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx); Tensor tKgK = gmem_thr_copy_K.partition_S(gK); Tensor tKsK = gmem_thr_copy_K.partition_D(sK); Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = gmem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tKgK))); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(sK); // if (thread0()) // { // print("tSgK\n"); print(tSgK); print("\n"); // print("tKgK\n"); print(tKgK); print("\n"); // } Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK); Tensor tKpK_smem = make_tensor(make_shape(size<2>(tSgK))); Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); #endif typename Kernel_traits::GmemTiledCopyV gmem_tiled_copy_V; auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx); Tensor tVgV = gmem_thr_copy_V.partition_S(gV); Tensor tVsV = gmem_thr_copy_V.partition_D(sV); // if (0 || thread(64)) // { // print("tksk "); print(tKsK); print("\n"); // print("tVsV "); print(tVsV); print("\n"); // } Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV))); Tensor tVcV = gmem_thr_copy_V.partition_S(cV); Tensor tVpV = make_tensor(make_shape(size<2>(tVgV))); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; // constexpr static int k0_lds_loops = 0; constexpr static int k0_lds_loops = 15; constexpr static int k0_loops = size<2>(tSrK_smem); constexpr static int k1_loops = size<2>(tOrVt); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; constexpr static int STAGE = 15; Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); #if 1 #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); // 这个也做过循环2类似的修改,但是性能不如现在的好,所以保持不变 int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < STAGE; i++) { lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 3; uint128_t buffer[BUFFER_SIZE]; buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); // if constexpr (STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); 
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); 
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm 
volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[0], tSrK_smem, 15); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[1], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[2], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } // We have key_padding_mask so we'll need to Check_inf if constexpr (n_masking_steps == 1) { softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } else { const bool is_first_masking_step = masking_step == 
0; is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); lds_direct_copy(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); // asm_ds_write(buffer[0], tVsV, 15); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); #pragma unroll for (int i = 0; i < k1_loops; i++) { cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); } // asm volatile("s_barrier\n\t"); } #endif #if 1 #pragma unroll for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < 16; i++) { lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 2; uint128_t buffer[BUFFER_SIZE]; // buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[0], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); 
buffer_load_copy(gK, buffer[1], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); // if constexpr (STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); 
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); 
__builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); __builtin_amdgcn_sched_barrier(0); __ds_read_m32x16_row_col<3, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 3>(tOsVt, tOrVt_copy_view); __builtin_amdgcn_sched_barrier(0); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[0], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[1], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); // We have key_padding_mask so we'll need to Check_inf softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, 
params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o); // asm volatile("s_barrier\n\t"); } #endif if (NoSplit) store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_gfx936(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, const float k_scale, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; // using Element = cutlass::half_t; using Element_O = typename Kernel_traits::Element_O; using Element = 
typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; // const float k_scale = *reinterpret_cast(params.k_scale_ptr); // if (thread0()) // { // printf("k_scale %.2f\n", k_scale); // } // constexpr int is_scale_equal_one = true; const int tidx = threadIdx.x; const index_t row_offset_q_nope = bidb * params.q_nope_batch_stride + m_block * kBlockM * params.q_nope_row_stride + bidh * params.q_nope_head_stride; Tensor gQ_nope = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_nope_ptr) + row_offset_q_nope), Shape, Int<512>>{}, make_stride(params.q_nope_row_stride, _1{})); const index_t row_offset_q_pe = bidb * params.q_pe_batch_stride + m_block * kBlockM * params.q_pe_row_stride + bidh * params.q_pe_head_stride; Tensor gQ_pe = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_pe_ptr) + row_offset_q_pe), Shape, Int<64>>{}, make_stride(params.q_pe_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); // Tensor sQ_nope = make_tensor(make_smem_ptr(shared_storage.smem_q_nope.data()), typename Kernel_traits::SmemLayoutQ_nope{}); // Tensor sQ_pe = make_tensor(make_smem_ptr(shared_storage.smem_q_pe.data()), typename Kernel_traits::SmemLayoutQ_pe{}); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename 
Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sVtNoSwizzle_fp8 = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle_fp8{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); // if (thread0()) // { // printf("sv sp srow = %p %p %p \n", sV.data().get(), sP.data().get(), sRow_max_reduce_buffer.data().get()); // } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_16_16_32 tiled_mma_16_16_32; auto thr_mma_16_16_32 = tiled_mma_16_16_32.get_thread_slice(tidx); typename Kernel_traits::TiledMma_int8 tiled_mma_int8; auto thr_mma_int8 = tiled_mma_int8.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O_16_32_16 tiled_mma_o_16_32_16; auto thr_mma_o_16_32_16 = tiled_mma_o_16_32_16.get_thread_slice(tidx); // typename Kernel_traits::TiledMma_O_int8 tiled_mma_o_int8; // auto thr_mma_o_int8 = tiled_mma_o_int8.get_thread_slice(tidx); #if 1 // 过lds读取q, 由于q是4个warp共用的 // typename Kernel_traits::GmemTiledCopyQ_nope gmem_tiled_copy_Q_nope; // auto gmem_thr_copy_Q_nope = gmem_tiled_copy_Q_nope.get_thread_slice(tidx); // Tensor tQgQ_nope = gmem_thr_copy_Q_nope.partition_S(gQ_nope); // Tensor tQsQ_nope = gmem_thr_copy_Q_nope.partition_D(sQ_nope); // Tensor cQ_nope = make_identity_tensor(make_shape(size<0>(gQ_nope), size<1>(gQ_nope))); // Tensor tQcQ_nope = gmem_thr_copy_Q_nope.partition_S(cQ_nope); // Tensor tQpQ_nope = make_tensor(make_shape(size<2>(tQgQ_nope))); // typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; // auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tQgQ_pe = gmem_thr_copy_Q.partition_S(gQ_pe); // Tensor tQsQ_pe = 
gmem_thr_copy_Q.partition_D(sQ_pe); // Tensor cQ_pe = make_identity_tensor(make_shape(size<0>(gQ_pe), size<1>(gQ_pe))); // Tensor tQcQ_pe = gmem_thr_copy_Q.partition_S(cQ_pe); // Tensor tQpQ_pe = make_tensor(make_shape(size<2>(tQgQ_pe))); // flash::copy(gmem_tiled_copy_Q_nope, tQgQ_nope, tQsQ_nope, tQcQ_nope, tQpQ_nope, // params.seqlen_q - m_block * kBlockM); // if (threadIdx.x < 128) // { // flash::copy(gmem_tiled_copy_Q, tQgQ_pe, tQsQ_pe, tQcQ_pe, tQpQ_pe, // params.seqlen_q - m_block * kBlockM); // } // __syncthreads(); // auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); // auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tSsQ_nope = smem_thr_copy_Q.partition_S(sQ_nope); // Tensor tSsQ_pe = smem_thr_copy_Q.partition_S(sQ_pe); // Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // for (int k = 0; k < 8; k++) // { // cute::copy(smem_tiled_copy_Q, tSsQ_nope(_, _, k), tSrQ_copy_view(_, _, k)); // } // for (int k = 0; k < 1; k++) // { // cute::copy(smem_tiled_copy_Q, tSsQ_pe(_, _, k), tSrQ_copy_view(_, _, k+8)); // } // cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ_pe = gmem_thr_copy_Q.partition_S(gQ_pe); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ_pe = make_identity_tensor(make_shape(size<0>(gQ_pe), size<1>(gQ_pe))); Tensor tQcQ_pe = gmem_thr_copy_Q.partition_S(cQ_pe); Tensor tQpQ_pe = make_tensor(make_shape(size<2>(tQgQ_pe))); Tensor tQgQ_nope = gmem_thr_copy_Q.partition_S(gQ_nope); Tensor cQ_nope = make_identity_tensor(make_shape(size<0>(gQ_nope), size<1>(gQ_nope))); Tensor tQpQ_nope = make_tensor(make_shape(size<2>(tQgQ_nope))); // if (threadIdx.x < 128) // if constexpr (std::is_same_v) { for (int m = 0; m < size<1>(tQgQ_nope); m++) { if (get<0>(tQcQ_pe(0, m, 0)) < params.seqlen_q - m_block * kBlockM) { for 
(int k = 0; k < 512/64; k++) { cute::copy(gmem_tiled_copy_Q, tQgQ_nope(_, m, k), tQsQ(_, m, k)); } for (int k = 0; k < 1; k++) { cute::copy(gmem_tiled_copy_Q, tQgQ_pe(_, m, k), tQsQ(_, m, 8)); } } } } __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); union Fp32_storage { float fp32; uint32_t u32; }; union Fp16_storage { __fp16 fp16_v; uint16_t tmp; }; // if constexpr (std::is_same_v) { // for (int i = 0; i < size(tSrQ_copy_view); i++) // { // uint16_t tmp = tSrQ_copy_view(i).storage; // Fp32_storage fp32; // fp32.u32 = tmp << 16; // Fp16_storage fp16_t; // fp16_t.fp16_v = static_cast<__fp16>(fp32.fp32); // tSrQ_copy_view(i) = cutlass::half_t::bitcast(fp16_t.tmp); // } // } #else auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); #endif #if 0 #else auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma_16_16_32); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); Tensor tSrK_int8 = thr_mma_int8.partition_fragment_B(gK); #endif auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o_16_32_16); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = 
smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle_fp8); // Tensor tOrVt_int8 = thr_mma_o_int8.partition_fragment_B(sVtNoSwizzle); // if (thread0()) // { // print("tOsVt "); print(tOsVt), print("\n"); // print("tOrVt "); print(tOrVt), print("\n"); // // print("tOrVt_int8 "); print(tOrVt_int8), print("\n"); // } constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; // constexpr static int k0_lds_loops = 0; const auto sk_data = sK.data(); const auto sRow_max_reduce_buffer_data = sRow_max_reduce_buffer.data(); constexpr auto sk_size = size(sK); const auto sP_data = sP.data(); const auto tSsK_data = tSsK.data(); const auto tOsVt_data = tOsVt.data(); const auto gK_data = gK.data(); constexpr static int BUFFER_SIZE = 1; // uint128_t buffer[BUFFER_SIZE]; constexpr short int wait_cnt = 8; { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = n_block % 2 == 1 ? sk_data + sk_size : sk_data; // tSsK.data() = n_block % 2 == 0? tSsK_data + sk_size : tSsK_data; // tOsVt.data() = n_block % 2 ? 
tOsVt_data + sk_size : tOsVt_data; // if (block0() && tidx < 64) // { // printf("tidx = %d Is_causal = %d n_block = %d addr %p sk %p tSsK = %p \n", tidx, Is_causal, n_block, &(sRow_max_reduce_buffer(0)), (&sK(0, 0)), (&tSsK(0, 0, 0))); // } #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } // __syncthreads(); // if (thread0()) // { // for (int i = 0; i < 64; i++) // { // for (int j = 0; j < 512; j++) // { // printf(" %.2f ", float(sK(i, j))); // } // printf("\n"); // } // } } Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; #if 1 for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) { // asm volatile("s_barrier\n\t"); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8); { tSsK.data() = n_block % 2 == 1 ? 
tSsK_data + sk_size : tSsK_data; uint32x4_t buffer[BUFFER_SIZE]; buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN); #if 0 #else gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); #endif // asm volatile("s_barrier\n\t"); } Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { acc_s(i) *= k_scale; } if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; // We have key_padding_mask so we'll need to Check_inf if constexpr (n_masking_steps == 1) { softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } else { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? 
softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data; Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); if (n_block > n_block_min) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block - 1; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data; sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data; tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data; #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN); } // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN); // gK.data() = gK.data() + (-offset_k); } // __syncthreads(); { tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data; #if 0 #else gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale); #endif } } #endif #if 1 for (; n_block >= n_block_min; --n_block) { // asm volatile("s_barrier\n\t"); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8); { tSsK.data() = n_block % 2 == 1 ? 
tSsK_data + sk_size : tSsK_data; uint32x4_t buffer[BUFFER_SIZE]; buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - (n_block - 1) * kBlockN); #if 0 #else gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); #endif // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); // asm volatile("s_barrier\n\t"); } if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { for (int i = 0; i < size(acc_s); i++) { acc_s(i) *= k_scale; } } sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data; Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); if (n_block > n_block_min) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block - 1; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data; // sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; // sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data; // tSsK.data() = (n_block - 1) % 2 ? 
tSsK_data + sk_size : tSsK_data; #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN); } // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN); // gK.data() = gK.data() + (-offset_k); } // __syncthreads(); { tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data; #if 0 #else gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale); #endif // tOsVt.data() = (n_block - 1) % 2 ? tOsVt_data + sk_size : tOsVt_data; } } #endif // if (thread0()) // { // printf("NoSplit %d \n", NoSplit); // } if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { for (int i = 0; i < size(acc_o); i++) { acc_o(i) *= k_scale; } } if (NoSplit) store_fp8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store_fp8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, const float k_scale, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; // const float k_scale = *reinterpret_cast(params.k_scale_ptr); // if (thread0()) // { // printf("k_scale %.2f\n", k_scale); // } // constexpr int is_scale_equal_one = true; const int tidx = threadIdx.x; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; 
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sVtNoSwizzle_fp8 = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle_fp8{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); // if (thread0()) // { // printf("sv sp srow = %p %p %p \n", sV.data().get(), sP.data().get(), sRow_max_reduce_buffer.data().get()); // } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_16_16_32 tiled_mma_16_16_32; auto thr_mma_16_16_32 = tiled_mma_16_16_32.get_thread_slice(tidx); typename Kernel_traits::TiledMma_int8 tiled_mma_int8; auto thr_mma_int8 = tiled_mma_int8.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O_16_32_16 
tiled_mma_o_16_32_16; auto thr_mma_o_16_32_16 = tiled_mma_o_16_32_16.get_thread_slice(tidx); // typename Kernel_traits::TiledMma_O_int8 tiled_mma_o_int8; // auto thr_mma_o_int8 = tiled_mma_o_int8.get_thread_slice(tidx); #if 1 // 过lds读取q, 由于q是4个warp共用的 typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ))); if (threadIdx.x < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); #else auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); #endif #if 0 #else auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma_16_16_32); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); Tensor tSrK_int8 = thr_mma_int8.partition_fragment_B(gK); #endif auto smem_tiled_copy_V = 
make_tiled_copy_B(Copy_Atom{}, tiled_mma_o_16_32_16); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle_fp8); // Tensor tOrVt_int8 = thr_mma_o_int8.partition_fragment_B(sVtNoSwizzle); // if (thread0()) // { // print("tOsVt "); print(tOsVt), print("\n"); // print("tOrVt "); print(tOrVt), print("\n"); // // print("tOrVt_int8 "); print(tOrVt_int8), print("\n"); // } constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; // constexpr static int k0_lds_loops = 0; const auto sk_data = sK.data(); const auto sRow_max_reduce_buffer_data = sRow_max_reduce_buffer.data(); constexpr auto sk_size = size(sK); const auto sP_data = sP.data(); const auto tSsK_data = tSsK.data(); const auto tOsVt_data = tOsVt.data(); const auto gK_data = gK.data(); constexpr static int BUFFER_SIZE = 1; // uint128_t buffer[BUFFER_SIZE]; constexpr short int wait_cnt = 8; { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = n_block % 2 == 1 ? sk_data + sk_size : sk_data; // tSsK.data() = n_block % 2 == 0? tSsK_data + sk_size : tSsK_data; // tOsVt.data() = n_block % 2 ? 
tOsVt_data + sk_size : tOsVt_data; // if (block0() && tidx < 64) // { // printf("tidx = %d Is_causal = %d n_block = %d addr %p sk %p tSsK = %p \n", tidx, Is_causal, n_block, &(sRow_max_reduce_buffer(0)), (&sK(0, 0)), (&tSsK(0, 0, 0))); // } #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } // __syncthreads(); // if (thread0()) // { // for (int i = 0; i < 64; i++) // { // for (int j = 0; j < 512; j++) // { // printf(" %.2f ", float(sK(i, j))); // } // printf("\n"); // } // } } Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; #if 1 for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) { // asm volatile("s_barrier\n\t"); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8); { tSsK.data() = n_block % 2 == 1 ? 
tSsK_data + sk_size : tSsK_data; uint32x4_t buffer[BUFFER_SIZE]; buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN); #if 0 #else gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); #endif // asm volatile("s_barrier\n\t"); } Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { acc_s(i) *= k_scale; } if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; // We have key_padding_mask so we'll need to Check_inf if constexpr (n_masking_steps == 1) { softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } else { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? 
softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data; Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); if (n_block > n_block_min) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block - 1; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data; sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data; tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data; #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN); } // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN); // gK.data() = gK.data() + (-offset_k); } // __syncthreads(); { tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data; #if 0 #else gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale); #endif } } #endif #if 1 for (; n_block >= n_block_min; --n_block) { // asm volatile("s_barrier\n\t"); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8); { tSsK.data() = n_block % 2 == 1 ? 
tSsK_data + sk_size : tSsK_data; uint32x4_t buffer[BUFFER_SIZE]; buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - (n_block - 1) * kBlockN); #if 0 #else gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); #endif // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); // asm volatile("s_barrier\n\t"); } if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { for (int i = 0; i < size(acc_s); i++) { acc_s(i) *= k_scale; } } sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data; Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); if (n_block > n_block_min) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block - 1; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data; // sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data; // sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data; // tSsK.data() = (n_block - 1) % 2 ? 
tSsK_data + sk_size : tSsK_data; #pragma unroll for (int i = 0; i < 8; i++) { lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN); } // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN); // gK.data() = gK.data() + (-offset_k); } // __syncthreads(); { tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data; #if 0 #else gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale); #endif // tOsVt.data() = (n_block - 1) % 2 ? tOsVt_data + sk_size : tOsVt_data; } } #endif // if (thread0()) // { // printf("NoSplit %d \n", NoSplit); // } if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) { for (int i = 0; i < size(acc_o); i++) { acc_o(i) *= k_scale; } } if (NoSplit) store_fp8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store_fp8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp1(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, const float k_scale, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; // const float k_scale = *reinterpret_cast(params.k_scale_ptr); // if (thread0()) // { // printf("k_scale %.2f\n", k_scale); // } // constexpr int is_scale_equal_one = true; const int tidx = threadIdx.x; const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64); const int lane_id = tidx % 64; const index_t row_offset_q = bidb * 
params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); // Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); // Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); // Tensor sVtNoSwizzle_fp8 = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle_fp8{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); typename Kernel_traits::TiledMma_16_16_32 tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor 
tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); Tensor tSrK = thr_mma.partition_fragment_B(sK); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); typename Kernel_traits::TiledMma_16_16_32_for_copy tiled_mma_for_copy; auto thr_mma_for_copy = tiled_mma_for_copy.get_thread_slice(tidx); auto smem_tiled_copy_K_for_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma_for_copy); auto smem_thr_copy_K_for_copy = smem_tiled_copy_K_for_copy.get_thread_slice(tidx); Tensor tSsK_for_copy = smem_thr_copy_K_for_copy.partition_S(sK); // if (thread0()) // { // printf(" tSsK\n "); // print(tSsK); // printf("\n"); // printf(" tSsK_for_copy\n "); // print(tSsK_for_copy); // printf("\n"); // } // if (block0()) // { // printf(" %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", float(tSrQ(0)), // float(tSrQ(1)), // float(tSrQ(2)), // float(tSrQ(3)), // float(tSrQ(4)), // float(tSrQ(5)), // float(tSrQ(6)), // float(tSrQ(7)) // ); // } const int *block_table = params.block_table + bidb * params.block_table_batch_stride; const auto gK_data = gK.data(); int n_block = n_block_max - 1; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); typedef __fp16 __fp16x4_t __attribute__((ext_vector_type(4))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; union Fp8_data{ uint32x2_t data_64; __hip_fp8x4_storage_t fp8_array[2]; }; Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = 
thr_mma_o.partition_fragment_B(sVt); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); union bf16_storage{ __fp16x8_t data_128; uint16_t data_array[8]; }; extern __shared__ char shared_memory[]; #if 0 #else for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); Fp8_data data[5]; for (int m = 0; m < 2; m++) { for (int k_idx = 0; k_idx < 18; k_idx+=4) { buffer_load_copy_fp8x2(gK, data[k_idx / 4].data_64, k_idx / 4, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); } uint16_t* k_lds_write_ptr = reinterpret_cast(shared_memory) + (warp_id / 4) * (32 * 64) + (warp_id % 4) * 8 * 64 + (lane_id * 8 + (warp_id % 4) * 16) % (64 * 8); for (int k_idx = 0; k_idx < 16; k_idx+=4) { // union // bf16_storage bf16_data; // __fp16x8_t *lds_ptr = reinterpret_cast<__fp16x8_t*>(&(tSsK(0, 0, k_idx + (warp_id % 4)))); for (int j = 0; j < 8; j += 4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx / 4].fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx / 4].fp8_array[j / 4])) + 1); auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); 
auto rst3 = convert_(f4); // bf16_data.data_array[j] = rst0.storage; // bf16_data.data_array[j + 1] = rst1.storage; // bf16_data.data_array[j + 2] = rst2.storage; // bf16_data.data_array[j + 3] = rst3.storage; k_lds_write_ptr[j] = rst0.storage; k_lds_write_ptr[j + 1] = rst1.storage; k_lds_write_ptr[j + 2] = rst2.storage; k_lds_write_ptr[j + 3] = rst3.storage; // 这样写无法向量化访存 ds_write_b128 // tSsK(j, 0, k_idx + (warp_id % 4)) = rst0; // tSsK(j + 1, 0, k_idx + (warp_id % 4)) = rst1; // tSsK(j + 2, 0, k_idx + (warp_id % 4)) = rst2; // tSsK(j + 3, 0, k_idx + (warp_id % 4)) = rst3; } // __fp16x8_t *lds_ptr = reinterpret_cast<__fp16x8_t*>(&(tSsK(0, 0, k_idx + (warp_id % 4)))); // *lds_ptr = bf16_data.data_128; k_lds_write_ptr += 32 * 128; } if (warp_id < 4) { int k_idx = 16; // uint16_t* k_lds_ptr = reinterpret_cast(shared_memory) + 32 * 128 * (k_idx / 4) // + (warp_id / 4) * (32 * 64) + (warp_id % 4) * 8 * 64 + (lane_id * 8 + (warp_id % 4) * 16) % (64 * 8); for (int j = 0; j < 8; j += 4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx / 4].fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx / 4].fp8_array[j / 4])) + 1); auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); k_lds_write_ptr[j] = rst0.storage; k_lds_write_ptr[j + 1] = rst1.storage; k_lds_write_ptr[j + 2] = rst2.storage; k_lds_write_ptr[j + 3] = rst3.storage; // tSsK(j, 0, k_idx + (warp_id % 4)) = rst0; // tSsK(j + 1, 0, k_idx + (warp_id % 4)) = rst1; // tSsK(j + 2, 0, k_idx + (warp_id % 4)) = rst2; // tSsK(j + 3, 0, k_idx + (warp_id % 
4)) = rst3; } } // if (thread0()) // // // if (block0()) // { // uint16_t* k_lds_ptr = reinterpret_cast(shared_memory); // for (int i = 0; i < 32 * 64; i++) // { // if (i % 64 == 0) // printf("\n"); // // for (int j = 0; j < 64; j++) // { // printf(" %.2f ", float(cutlass::bfloat16_t::bitcast(k_lds_ptr[i]))); // } // } // // printf("tidx = %d warp_id %d %p %p \n",tidx, warp_id, &tSsK(0, 0, 0), &tSsK(0, 0, 1)); // } __syncthreads(); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int<32>>{}); clear(acc_s); #if 0 #else int row_ = lane_id % 16; int col_ = lane_id / 16; uint16_t* k_lds_ptr = reinterpret_cast(shared_memory) + (row_ % 4) * (8 * 64 + 16) + (row_ / 4) * (8 * 8) + col_ * 8 + (warp_id / 4) * (4 * 64) + (warp_id / 4) * ( (lane_id == 47 || lane_id == 63) * (- 8 * 64) ); uint16_t* k_lds_ptr1 = reinterpret_cast(shared_memory) + (row_ % 4) * (8 * 64 + 16) + (row_ / 4) * (8 * 8) + col_ * 8 + (warp_id / 4) * (4 * 64) + (warp_id / 4) * ( (lane_id == 45 || lane_id == 61 || lane_id == 14 || lane_id == 30 || lane_id == 46 || lane_id == 62 || lane_id == 15 || lane_id == 31 || lane_id == 47 || lane_id == 63 ) * (- 8 * 64) ) + 4 * 8; for (int k = 0; k < 18; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 64; } #endif // if (block0() && m == 0 && warp_id == 0) // { // printf(" accs %d %.2f %.2f %.2f %.2f\n", threadIdx.x, acc_s(0), acc_s(1), acc_s(2), acc_s(3)); // } Tensor cS = make_identity_tensor(Shape, Int<32>>{}); Tensor tScS = 
thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) + m * 32 >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i)) + m * 32) > col_limit_right) acc_s(i) = -INFINITY; } } { const bool is_first_masking_step = masking_step == 0 && m == 0; is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); Tensor tOrP = convert_layout_acc_Aregs_tp1(tiled_mma, tiled_mma_o, rP, sP); // // __builtin_amdgcn_sched_barrier(0); // for (int k_idx = 0; k_idx < 18; k_idx+=4) // { // buffer_load_copy_fp8(gK, data[k_idx / 4].data_64, k_idx / 4, params.k_row_stride, 1 * 32, seqlen_k - n_block * kBlockN); // } { // for (int k = 0; k < 4; k++) // { #if 1 // } int row = lane_id / 4; int col = lane_id % 4; __fp16* v_lds_ptr_k0_base = reinterpret_cast<__fp16*>(shared_memory) + (warp_id / 4) * 32 + (row / 4) * 8 * 8 + (row % 4) * (8 * 64 + 16) + col * 8; __ds_read_m32x16_row_col_lds<0, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<1, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<2, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<3, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<4, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<5, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<6, 0>(v_lds_ptr_k0_base, 
tOrVt_copy_view); __ds_read_m32x16_row_col_lds<7, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); int k1_n0 = ((warp_id / 4) == 0 && (lane_id == 62 || lane_id == 63)) * (-8 * 64); int k1_n1 = ((warp_id / 4) == 1 && (lane_id >= 54 && lane_id < 64)) * (-8 * 64); __fp16* v_lds_ptr_k1_base = v_lds_ptr_k0_base + 4 * 64 + k1_n0 + k1_n1; __ds_read_m32x16_row_col_lds<0, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<1, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<2, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<3, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<4, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<5, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<6, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<7, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); #endif } __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier \n\t"); __builtin_amdgcn_sched_barrier(0); } } #endif if (NoSplit) store_fp8_tp1(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store_fp8_tp1(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); }
// compute_attn_1rowblock_splitkv_mla_gfx936_tp1
//
// Processes one query row block (m_block) of batch entry `bidb` / head `bidh` against the K blocks
// n_block_max - 1 down to n_block_min, then passes the output accumulator and the softmax state to
// store_fp8_tp1().
//
// Flow, as written in the body:
//   1. Q is copied from global memory straight into the MMA A fragment (tSrQ), limited to
//      params.seqlen_q - m_block * kBlockM rows.
//   2. For every K block the block-table entry is fetched with a scalar s_load_dword and gK is
//      rebased to gK_data + entry * params.k_batch_stride.
//   3. Each K block is handled in two passes (m = 0, 1; column offset m * 32). A pass issues five
//      lds_direct_copy_tp1 calls and then accumulates 18 k-slices of Q*K into acc_s in five stages;
//      every stage is preceded by `s_waitcnt vmcnt(N)` (N = 4, 3, 2, 1, 0) followed by `s_barrier`.
//   4. Scores past seqlen_k (or past the causal column limit when Is_causal) are set to -INFINITY,
//      softmax_rescale_o() is applied, the result is converted via flash::convert_type and
//      convert_layout_acc_Aregs_tp1, and two GEMMs accumulate into acc_o. The B operands for those
//      GEMMs are addressed through the `shared_memory` array and __ds_read_m32x16_row_col_lds
//      (presumably filling tOrVt through tOrVt_copy_view -- TODO confirm against that helper).
//
// Parameters:
//   params          kernel arguments; this function reads q_ptr, k_ptr, the q/k strides,
//                   h_h_k_ratio, block_table, block_table_batch_stride, seqlen_q, ngroups and
//                   scale_softmax_log2.
//   bidb            batch index (scales q_batch_stride and block_table_batch_stride).
//   bidh            head index (scales q_head_stride; the K head is bidh / params.h_h_k_ratio).
//   m_block         query row-block index (row offset m_block * kBlockM).
//   n_split_idx     only forwarded to store_fp8_tp1().
//   seqlen_k        key length used for the copy bound and for masking.
//   n_block_min,
//   n_block_max     K block range; iteration runs from n_block_max - 1 down to n_block_min inclusive.
//   NoSplit         run-time flag choosing which of the two store_fp8_tp1() calls at the end is taken.
//   shared_storage  supplies smem_k, smem_p and smem_row_max.
//
// Notes:
//   - sVt and sVtNoSwizzle are views over sK.data(), and gV is built from params.k_ptr, so V shares
//     K's storage in this kernel.
//   - The hand-computed LDS pointers start at `shared_memory`; presumably shared_storage.smem_k sits
//     at the start of that array -- TODO confirm.
//   - NOTE(review): the commented-out `block_table[n_block - 1]` does not match the live code, which
//     loads from block_table + n_block.
//   - NOTE(review): tSrK is written only inside the `if constexpr (std::is_same_v...)` branches; if
//     that condition is false the GEMMs consume an unwritten fragment -- confirm this cannot happen.
//   - NOTE(review): gV, sVtNoSwizzle, tSsK, tSsK_for_copy and tOsVt appear unused in this function.
template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; const int 
warp_id = __builtin_amdgcn_readfirstlane(tidx / 64); const int lane_id = tidx % 64; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); // Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); typename Kernel_traits::TiledMma_16_16_32 tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); 
flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); Tensor tSrK = thr_mma.partition_fragment_B(sK); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); typename Kernel_traits::TiledMma_16_16_32_for_copy tiled_mma_for_copy; auto thr_mma_for_copy = tiled_mma_for_copy.get_thread_slice(tidx); auto smem_tiled_copy_K_for_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma_for_copy); auto smem_thr_copy_K_for_copy = smem_tiled_copy_K_for_copy.get_thread_slice(tidx); Tensor tSsK_for_copy = smem_thr_copy_K_for_copy.partition_S(sK); const int *block_table = params.block_table + bidb * params.block_table_batch_stride; const auto gK_data = gK.data(); int n_block = n_block_max - 1; Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); extern __shared__ char shared_memory[]; typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); union bf16_storage{ __fp16x8_t data_128; uint16_t data_array[8]; };
// Main loop: visit K blocks from n_block_max - 1 down to n_block_min. masking_step counts the
// iterations so that the very first pass (masking_step == 0 && m == 0) can be distinguished when
// softmax_rescale_o() is called. gK is rebased from the saved gK_data on every iteration, so the
// block-table offsets do not accumulate across iterations.
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); for (int m = 0; m < 2; m++) { // asm volatile(" s_barrier\n\t"); // 
// Queue the lds_direct_copy_tp1 transfers (indices 0..4) for the current pass (column offset m * 32).
// The staged `s_waitcnt vmcnt(4..0)` + `s_barrier` pairs further down presumably release them to
// the GEMM stages one at a time (vmcnt semantics per the AMD ISA manual -- TODO confirm for this target).
lds_direct_copy_tp1(gK, sK, 0, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); lds_direct_copy_tp1(gK, sK, 1, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); lds_direct_copy_tp1(gK, sK, 2, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); lds_direct_copy_tp1(gK, sK, 3, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); lds_direct_copy_tp1(gK, sK, 4, params.k_row_stride, m * 32, seqlen_k - n_block * kBlockN); // asm volatile(" s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int<32>>{}); clear(acc_s); int row_ = lane_id % 16; int col_ = lane_id / 16; uint16_t* k_lds_ptr = reinterpret_cast(shared_memory) + (row_ % 4) * (8 * 64 + 16) + (row_ / 4) * (8 * 8) + col_ * 8 + (warp_id / 4) * (4 * 64) + (warp_id / 4) * ( (lane_id == 47 || lane_id == 63) * (- 8 * 64) ); uint16_t* k_lds_ptr1 = reinterpret_cast(shared_memory) + (row_ % 4) * (8 * 64 + 16) + (row_ / 4) * (8 * 8) + col_ * 8 + (warp_id / 4) * (4 * 64) + (warp_id / 4) * ( (lane_id == 45 || lane_id == 61 || lane_id == 14 || lane_id == 30 || lane_id == 46 || lane_id == 62 || lane_id == 15 || lane_id == 31 || lane_id == 47 || lane_id == 63 ) * (- 8 * 64) ) + 4 * 8; // for (int k = 0; k < 1; k+=2) __builtin_amdgcn_sched_barrier(0); asm volatile(" s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int k = 0; k < 4; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f %p \n", threadIdx.x, warp_id, float(tSrK(0, 0, 0)), float(tSrK(1, 0, 0)), float(tSrK(2, 0, 0)), float(tSrK(3, 0, 0)), k_lds_ptr); // } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr 
(std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 64; }
// Every following stage repeats the same pattern with the vmcnt argument counting down (3, 2, 1, 0):
// wait + barrier, then each loop iteration reads two 128-bit chunks from LDS (k_lds_ptr, k_lds_ptr1),
// copies their eight 16-bit lanes into tSrK slices k and k + 1, runs one GEMM per slice into acc_s,
// and advances both pointers by 32 * 64 elements.
__builtin_amdgcn_sched_barrier(0); asm volatile(" s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int k = 4; k < 8; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f %p \n", threadIdx.x, warp_id, float(tSrK(0, 0, 0)), float(tSrK(1, 0, 0)), float(tSrK(2, 0, 0)), float(tSrK(3, 0, 0)), k_lds_ptr); // } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 64; } __builtin_amdgcn_sched_barrier(0); asm volatile(" s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int k = 8; k < 12; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f %p \n", threadIdx.x, warp_id, float(tSrK(0, 0, 0)), float(tSrK(1, 0, 0)), float(tSrK(2, 0, 0)), float(tSrK(3, 0, 0)), k_lds_ptr); // } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } 
cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 64; } __builtin_amdgcn_sched_barrier(0); asm volatile(" s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int k = 12; k < 16; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f %p \n", threadIdx.x, warp_id, float(tSrK(0, 0, 0)), float(tSrK(1, 0, 0)), float(tSrK(2, 0, 0)), float(tSrK(3, 0, 0)), k_lds_ptr); // } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 64; } __builtin_amdgcn_sched_barrier(0); asm volatile(" s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int k = 16; k < 18; k+=2) { bf16_storage data; data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f %p \n", threadIdx.x, warp_id, float(tSrK(0, 0, 0)), float(tSrK(1, 0, 0)), float(tSrK(2, 0, 0)), float(tSrK(3, 0, 0)), k_lds_ptr); // } cute::gemm(tiled_mma, tSrQ(_, _, k), tSrK(_, _, k), acc_s); data.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_ptr1); for (int i = 0; i < 8; i++) { if constexpr (std::is_same_v) { tSrK(i, 0, k + 1) = cutlass::bfloat16_t::bitcast(data.data_array[i]); } } cute::gemm(tiled_mma, tSrQ(_, _, k+1), tSrK(_, _, k+1), acc_s); k_lds_ptr += 32 * 64; k_lds_ptr1 += 32 * 
64; } // if (block0()) // { // printf("tidx = %d warp_id = %d %.2f %.2f %.2f %.2f \n", threadIdx.x, warp_id, acc_s(0), acc_s(1), acc_s(2), acc_s(3) ); // } // asm volatile(" s_barrier\n\t"); Tensor cS = make_identity_tensor(Shape, Int<32>>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) + m * 32 >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i)) + m * 32) > col_limit_right) acc_s(i) = -INFINITY; } } { const bool is_first_masking_step = masking_step == 0 && m == 0; is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } // if (block0()) // { // printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3) ); // } Tensor rP = flash::convert_type(acc_s); Tensor tOrP = convert_layout_acc_Aregs_tp1(tiled_mma, tiled_mma_o, rP, sP); { // for (int k = 0; k < 4; k++) // { #if 1 // } int row = lane_id / 4; int col = lane_id % 4; __fp16* v_lds_ptr_k0_base = reinterpret_cast<__fp16*>(shared_memory) + (warp_id / 4) * 32 + (row / 4) * 8 * 8 + (row % 4) * (8 * 64 + 16) + col * 8; __ds_read_m32x16_row_col_lds<0, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<1, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<2, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<3, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<4, 0>(v_lds_ptr_k0_base, 
tOrVt_copy_view); __ds_read_m32x16_row_col_lds<5, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<6, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<7, 0>(v_lds_ptr_k0_base, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); int k1_n0 = ((warp_id / 4) == 0 && (lane_id == 62 || lane_id == 63)) * (-8 * 64); int k1_n1 = ((warp_id / 4) == 1 && (lane_id >= 54 && lane_id < 64)) * (-8 * 64); __fp16* v_lds_ptr_k1_base = v_lds_ptr_k0_base + 4 * 64 + k1_n0 + k1_n1; __ds_read_m32x16_row_col_lds<0, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<1, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<2, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<3, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<4, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<5, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<6, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); __ds_read_m32x16_row_col_lds<7, 1>(v_lds_ptr_k1_base, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); #endif } __builtin_amdgcn_sched_barrier(0); asm volatile(" s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } }
// Epilogue: pass the accumulated output and the softmax state to store_fp8_tp1(); the run-time
// NoSplit flag selects which of the two calls is made.
if (NoSplit) store_fp8_tp1(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store_fp8_tp1(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); } // #if defined(__gfx936__) // #if 1 template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = 
Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); // if (thread0()) // { // printf("sv sp srow = %p %p %p \n", sV.data().get(), sP.data().get(), sRow_max_reduce_buffer.data().get()); // } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); #if 1 // lds_direct_copy(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // 
lds_direct_copy(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // lds_direct_copy(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM); // // if (thread0()) // // { // // for (int i = 0; i < 16; i++) // // { // // for (int j = 0; j < 576; j++) // // { // // printf(" %.2f ", float(sQ(i, j))); // // } // // printf("\n"); // // } // // } // auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); // auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); // asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); // asm 
volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16)); // cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17)); // __syncthreads(); //过lds读取q, 由于q是4个warp共用的 typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ))); if (threadIdx.x < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); #else auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); 
__syncthreads(); #endif #if 0 auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSrK = thr_mma.partition_fragment_B(gK); Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = smem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tSgK))); #else typename Kernel_traits::GmemTiledCopyK gmem_tiled_copy_K; auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx); Tensor tKgK = gmem_thr_copy_K.partition_S(gK); Tensor tKsK = gmem_thr_copy_K.partition_D(sK); Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = gmem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tKgK))); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(sK); // if (thread0()) // { // print("tSgK\n"); print(tSgK); print("\n"); // print("tKgK\n"); print(tKgK); print("\n"); // } Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK); Tensor tKpK_smem = make_tensor(make_shape(size<2>(tSgK))); Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); #endif typename Kernel_traits::GmemTiledCopyV gmem_tiled_copy_V; auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx); Tensor tVgV = gmem_thr_copy_V.partition_S(gV); Tensor tVsV = gmem_thr_copy_V.partition_D(sV); // if (0 || thread(64)) // { // print("tksk "); print(tKsK); print("\n"); // print("tVsV "); print(tVsV); print("\n"); // } Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV))); Tensor tVcV = gmem_thr_copy_V.partition_S(cV); Tensor tVpV = make_tensor(make_shape(size<2>(tVgV))); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = 
smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; // constexpr static int k0_lds_loops = 0; constexpr static int k0_lds_loops = 15; constexpr static int k0_loops = size<2>(tSrK_smem); constexpr static int k1_loops = size<2>(tOrVt); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; constexpr static int STAGE = 15; Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); #if 1 #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); // 这个也做过循环2类似的修改,但是性能不如现在的好,所以保持不变 int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < STAGE; i++) { lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 3; uint128_t buffer[BUFFER_SIZE]; buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); // if constexpr 
(STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, 
k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); 
__builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[0], tSrK_smem, 15); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[1], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[2], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } // We have key_padding_mask so we'll need to Check_inf 
if constexpr (n_masking_steps == 1) { softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } else { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); lds_direct_copy(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); // asm_ds_write(buffer[0], tVsV, 15); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); #pragma unroll for (int i = 0; i < k1_loops; i++) { cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); } // asm volatile("s_barrier\n\t"); } #endif #if 1 #pragma unroll for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < 16; i++) { lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 2; uint128_t buffer[BUFFER_SIZE]; // 
buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[0], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[1], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); // if constexpr (STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); 
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); 
__builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); __builtin_amdgcn_sched_barrier(0); __ds_read_m32x16_row_col<3, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<3, 3>(tOsVt, tOrVt_copy_view); __builtin_amdgcn_sched_barrier(0); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[0], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); buffer_to_tensor(buffer[1], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt 
lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); // We have key_padding_mask so we'll need to Check_inf softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view); __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o); // asm volatile("s_barrier\n\t"); } #endif if (NoSplit) store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); else store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax); } // #elif 0 // #elif defined(__gfx928__) template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int 
kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); #if 1 // 过lds读取q, 由于q是4个warp共用的 typename Kernel_traits::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = 
gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ))); if (threadIdx.x < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); #else auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); #endif #if 0 auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSrK = thr_mma.partition_fragment_B(gK); Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = smem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tSgK))); #else typename Kernel_traits::GmemTiledCopyK gmem_tiled_copy_K; auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx); Tensor tKgK = gmem_thr_copy_K.partition_S(gK); Tensor tKsK = gmem_thr_copy_K.partition_D(sK); Tensor cK = 
make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); Tensor tKcK = gmem_thr_copy_K.partition_S(cK); Tensor tKpK = make_tensor(make_shape(size<2>(tKgK))); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(sK); Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK); Tensor tKpK_smem = make_tensor(make_shape(size<2>(tSgK))); Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); #endif typename Kernel_traits::GmemTiledCopyV gmem_tiled_copy_V; auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx); Tensor tVgV = gmem_thr_copy_V.partition_S(gV); Tensor tVsV = gmem_thr_copy_V.partition_D(sV); Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV))); Tensor tVcV = gmem_thr_copy_V.partition_S(cV); Tensor tVpV = make_tensor(make_shape(size<2>(tVgV))); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; // constexpr static int k0_lds_loops = 0; constexpr static int k0_lds_loops = 16; constexpr static int k0_loops = size<2>(tSrK_smem); constexpr static int k1_loops = size<2>(tOrVt); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; int cur_block_table = block_table[n_block]; index_t offset_k = cur_block_table * params.k_batch_stride; constexpr static int BUFFER_SIZE = 4; uint128_t buffer[BUFFER_SIZE]; gK.data() = gK.data() + (offset_k); buffer_load_copy(gK, buffer[0], 0, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[1], 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[2], 2, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); #if 1 #pragma unroll for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); // 计算0~11 #if 1 #pragma unroll for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) { // asm volatile("s_waitcnt vmcnt(3) \n\t \n\t"); asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i)); buffer_load_copy(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s); // asm volatile("s_barrier\n\t"); } // asm volatile("s_barrier\n\t"); #endif #if 0 #else // 计算 13-15 
const int k_idx = k0_lds_loops - BUFFER_SIZE + 1; asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s); asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s); // asm volatile("s_barrier\n\t"); // 读取16-17 buffer_load_copy(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_to_tensor(buffer[1], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); buffer_to_tensor(buffer[2], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); #endif const bool is_masking_step = masking_step > 0; const bool is_first_masking_step = masking_step == n_masking_steps; if (is_masking_step) { Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - 
n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } } // We have key_padding_mask so we'll need to Check_inf is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : is_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); Tensor tOrP = convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); #if 1 // 第15块已经读取到了buffer[3]中 asm_ds_write(buffer[3], tVsV, 15); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); #endif gK.data() = gK.data() + (-offset_k); if (n_block > n_block_min) { cur_block_table = block_table[n_block - 1]; offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); buffer_load_copy(gK, buffer[0], 0, params.k_row_stride, offset_k); buffer_load_copy(gK, buffer[1], 1, params.k_row_stride, offset_k); buffer_load_copy(gK, buffer[2], 2, params.k_row_stride, offset_k); } Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); #pragma unroll for (int i = 0; i < k1_loops; i++) { cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); } __builtin_amdgcn_sched_barrier(0); asm volatile(" s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } #endif if (NoSplit) 
store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax);
    else
        store(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax);
    // (The two store() calls above end compute_attn_1rowblock_splitkv_mla_gfx928. They are
    //  textually identical in this copy; any template arguments that distinguish the
    //  NoSplit case are not visible here.)
}
// #endif

/// Split-KV MLA forward kernel.
///
/// Each thread block takes its query row-block from blockIdx.x, its head from blockIdx.y
/// and its scheduler partition from blockIdx.z. The partition's metadata record supplies
/// a range of batch entries [begin_idx, end_idx], the key positions at which the range
/// starts and ends, and the split index of the first entry. For every batch entry in the
/// range the kernel derives the KV block range and calls the architecture- and
/// KV-dtype-specific compute_attn_1rowblock_* routine.
///
/// NOTE(review): the template parameter list and the template argument lists of the
/// dispatched calls are not present in this copy of the file; confirm against upstream.
///
/// @param params  Kernel parameters (scheduler metadata, cu_seqlens_k, k_scale_ptr, ...).
template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel(const Flash_fwd_mla_params params)
{
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    // Dynamic shared memory, reinterpreted as the kernel's shared-storage struct.
    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast(shared_memory);

    // Metadata record of this partition: four ints read as one int4, plus a fifth int.
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    // Nothing to do for this partition when it starts at or past the batch count.
    if (begin_idx >= params.b)
        return;
    int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);

    // k_scale is only read from memory for the FP8 KV-cache types; otherwise it is 1.
    const float k_scale = KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 || KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 ? *reinterpret_cast(params.k_scale_ptr) : 1.0f;
    const bool is_scale_equal_one = (k_scale == 1.0);

#pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id)
    {
        // Only the first batch entry of the partition can start at a non-zero split
        // index / key position; only the last one can stop before the end of its keys.
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = *(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        // True when this partition covers every KV block of the batch entry.
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
        if (batch_id > begin_idx)
        {
            __syncthreads(); // Barrier between two tiles.
        }
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto)
        {
            // Non-FP8 KV cache: pick the implementation by target architecture. When
            // none of the macros below is defined, this branch compiles to nothing.
#if defined(__gfx936__) || defined(__gfx938__)
            {
                flash::compute_attn_1rowblock_splitkv_mla_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#elif defined(__gfx928__)
            {
                flash::compute_attn_1rowblock_splitkv_mla_gfx928(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#endif
        }
        else if (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 || KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2)
        {
            // FP8 KV cache: only implemented for gfx936 / gfx938.
            // NOTE(review): both arms of the ternary are textually identical here;
            // presumably they differ in a template argument tied to is_scale_equal_one
            // that is not visible in this copy -- confirm.
#if defined(__gfx936__) || defined(__gfx938__)
            is_scale_equal_one ?
                flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, k_scale, shared_storage):
                flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, k_scale, shared_storage);
#endif
        }
    }
}

/// "tp1" variant of the split-KV MLA forward kernel.
///
/// Same partition / batch-range handling as flash_fwd_splitkv_mla_kernel, but it
/// dispatches to the *_tp1 compute routines. Only gfx936 / gfx938 have an active
/// implementation: the gfx928 branch is an empty block (its call is commented out).
///
/// NOTE(review): k_scale is read for both kFp8E4M3 and kFp8E5M2, but only kFp8E5M2 has a
/// dispatch branch, and is_scale_equal_one is not referenced in the visible body --
/// confirm whether this is intended.
///
/// @param params  Kernel parameters (scheduler metadata, cu_seqlens_k, k_scale_ptr, ...).
template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel_tp1(const Flash_fwd_mla_params params)
{
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    // Dynamic shared memory, reinterpreted as the kernel's shared-storage struct.
    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast(shared_memory);

    // Metadata record of this partition: four ints read as one int4, plus a fifth int.
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    // Nothing to do for this partition when it starts at or past the batch count.
    if (begin_idx >= params.b)
        return;
    int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);

    // k_scale is only read from memory for the FP8 KV-cache types; otherwise it is 1.
    const float k_scale = KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 || KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 ? *reinterpret_cast(params.k_scale_ptr) : 1.0f;
    const bool is_scale_equal_one = (k_scale == 1.0);

#pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id)
    {
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = *(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        // True when this partition covers every KV block of the batch entry.
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
        if (batch_id > begin_idx)
        {
            __syncthreads(); // Barrier between two tiles.
        }
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto)
        {
#if defined(__gfx936__) || defined(__gfx938__)
            {
                flash::compute_attn_1rowblock_splitkv_mla_gfx936_tp1(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#elif defined(__gfx928__)
            {
                // flash::compute_attn_1rowblock_splitkv_mla_gfx928(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#endif
        }
        else if (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2)
        {
#if defined(__gfx936__) || defined(__gfx938__)
            flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp1(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, k_scale, shared_storage);
#endif
        }
    }
}

/// "nope_pe" variant of the split-KV MLA forward kernel.
///
/// Same partition / batch-range handling as flash_fwd_splitkv_mla_kernel, but it
/// dispatches to the *_nope_pe compute routines. Only gfx936 / gfx938 have an active
/// implementation: the gfx928 branch is an empty block (its call is commented out).
///
/// NOTE(review): k_scale is read for both kFp8E4M3 and kFp8E5M2, but only kFp8E5M2 has a
/// dispatch branch, and is_scale_equal_one is not referenced in the visible body --
/// confirm whether this is intended.
///
/// @param params  Kernel parameters (scheduler metadata, cu_seqlens_k, k_scale_ptr, ...).
template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel_nope_pe(const Flash_fwd_mla_params params)
{
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    // Dynamic shared memory, reinterpreted as the kernel's shared-storage struct.
    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast(shared_memory);

    // Metadata record of this partition: four ints read as one int4, plus a fifth int.
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    // Nothing to do for this partition when it starts at or past the batch count.
    if (begin_idx >= params.b)
        return;
    int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);

    // k_scale is only read from memory for the FP8 KV-cache types; otherwise it is 1.
    const float k_scale = KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 || KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 ? *reinterpret_cast(params.k_scale_ptr) : 1.0f;
    const bool is_scale_equal_one = (k_scale == 1.0);

#pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id)
    {
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = *(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        // True when this partition covers every KV block of the batch entry.
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
        if (batch_id > begin_idx)
        {
            __syncthreads(); // Barrier between two tiles.
        }
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto)
        {
#if defined(__gfx936__) || defined(__gfx938__)
            {
                flash::compute_attn_1rowblock_splitkv_mla_nope_pe_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#elif defined(__gfx928__)
            {
                // flash::compute_attn_1rowblock_splitkv_mla_gfx928(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
            }
#endif
        }
        else if (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2)
        {
#if defined(__gfx936__) || defined(__gfx938__)
            flash::compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, k_scale, shared_storage);
            // flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx936(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, k_scale, shared_storage);
#endif
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Head of the combine kernel; its definition continues past this block and is kept
// unchanged. Visible so far: one thread block per (batch, head * seqlen_q) index, and an
// early return when the batch entry has exactly one split.
template __global__ void __launch_bounds__(kNThreads_, 1)
flash_fwd_splitkv_mla_combine_kernel(const Flash_fwd_mla_params params)
{
    constexpr int kNThreads = kNThreads_;
    constexpr int Warp_Size = 64;
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;
    const int hs = params.h * params.seqlen_q;
    const int batch_idx = bidx / hs;
    const int hs_idx = bidx % hs;

    // num_splits_ptr is read at batch_idx and batch_idx + 1; the difference is the
    // number of splits belonging to this batch entry.
    const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
    const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
    //FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
    if (actual_num_splits == 1)
        return;

    __shared__ ElementAccum sLseScale[kMaxSplits];

    const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
    const index_t row_offset_lse = bidx;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
                                   Shape>{}, make_stride(hs));
    Tensor gLSE =
make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape<_1>{}, Stride<_1>{}); const int warp_idx = tidx / Warp_Size; if (warp_idx == 0) { constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, Warp_Size); float local_lse[kNLsePerThread]; for (int i = 0; i < kNLsePerThread; ++i) { const int split = i * Warp_Size + tidx; local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; } float max_lse = -INFINITY; for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); for (int offset = 32; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor(max_lse, offset, 64)); max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf float sum_lse = 0; for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + __builtin_amdgcn_exp2f((local_lse[i] - max_lse) * 1.4426950408889634); for (int offset = 32; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor(sum_lse, offset, 64); float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
INFINITY : logf(sum_lse) + max_lse; if (tidx == 0) gLSE(0) = global_lse; for (int i = 0; i < kNLsePerThread; ++i) { const int split = i * 64 + tidx; if (split < actual_num_splits) sLseScale[split] = __builtin_amdgcn_exp2f((local_lse[i] - global_lse) * 1.4426950408889634); } } __syncthreads(); static_assert(kHeadDimV % kNThreads == 0); #if 1 constexpr int Elements = kHeadDimV / kNThreads; const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape<_1, Int>{}, Stride, _1>{}); using GmemLayoutAtomOaccum = Layout>, Stride, _1>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy( Copy_Atom{}, // Layout>>{}, GmemLayoutAtomOaccum{}, Layout>>{})); #else constexpr int Elements = 1; const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape>{}, Stride<_1>{}); using GmemTiledCopyOaccum = decltype(make_tiled_copy( Copy_Atom{}, Layout>>{}, Layout>>{})); #endif GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); // if (thread0()) // { // print(gmem_thr_copy_Oaccum); print("\n"); // // print(gOaccum); // } #if 1 Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); Tensor tOrO = make_tensor(shape(tOgOaccum)); clear(tOrO); for (int split = 0; split < actual_num_splits; ++split) { cute::copy(tOgOaccum, tOrOaccum); ElementAccum lse_scale = sLseScale[split]; for (int i = 0; i < size(tOrO); ++i) { tOrO(i) += lse_scale * tOrOaccum(i); } tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; } // if (block(0)) // for (int i = 0; i < size(tOrO); i++) // { // tOrO(i) = 1.0; // } Tensor rO = flash::convert_type(tOrO); // if (thread(0)) // { // printf("tidx %d %d %.4f %.4f %.4f %.4f \n", tidx, (int)size(tOrO), float(rO(0)), 
// [Layout note] The next physical line holds the end of
// flash_fwd_splitkv_mla_combine_kernel, all of run_flash_splitkv_fwd_mla, and the
// start of run_flash_splitkv_fwd_mla_tp1.
//
// Combine-kernel tail: head_idx = (bidx - batch_idx * hs) / params.seqlen_q and row is
// the remainder; the destination is params.o_ptr + batch_idx * params.o_batch_stride
// + head_idx * params.o_head_stride + row * params.o_row_stride, and each thread copies
// rO to that address offset by tidx * Elements.
//
// run_flash_splitkv_fwd_mla(params, stream): host-side launcher.
//   - asserts params.page_block_size == Kernel_traits::kBlockN;
//   - num_m_block = ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
//   - inside BOOL_SWITCH(params.is_causal, Is_causal, ...) launches
//     flash::flash_fwd_splitkv_mla_kernel; smem_size is the constant 65536 and the
//     cudaFuncSetAttribute call is commented out;
//   - inside MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, ...) launches
//     flash_fwd_splitkv_mla_combine_kernel instantiated with Kernel_traits::Element,
//     ElementAccum, index_t, kHeadDimV, kMaxSplits and kNThreads = 128;
//     grid_combine is sized params.b * params.h * params.seqlen_q;
//   - CHECK_CUDA_KERNEL_LAUNCH() follows each launch.
// NOTE(review): the launch configurations are not visible in this text, so how
// num_m_block, smem_size, grid_combine and stream are passed cannot be confirmed here.
//
// run_flash_splitkv_fwd_mla_tp1 begins identically but selects
// flash::flash_fwd_splitkv_mla_kernel_tp1.
float(rO(1)), float(rO(2)), float(rO(3))); // } const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); cute::copy(rO, gO); #endif } } // namespace flash //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // auto kernel = &flash::flash_fwd_splitkv_mla_kernel_tp1; auto kernel = &flash::flash_fwd_splitkv_mla_kernel; constexpr size_t smem_size = 65536; // printf("smem_size = %d\n", smem_size); // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_splitkv_fwd_mla_tp1(Flash_fwd_mla_params &params, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel_tp1; // auto kernel = 
// [Layout note] The next physical line holds the rest of run_flash_splitkv_fwd_mla_tp1,
// all of run_flash_splitkv_fwd_mla_q_nope_pe, and the start of run_mha_fwd_splitkv_mla.
//
// Rest of run_flash_splitkv_fwd_mla_tp1: smem_size is the constant 65536 (the
// cudaFuncSetAttribute call is commented out); after the main launch it builds
// grid_combine = params.b * params.h * params.seqlen_q and launches
// flash_fwd_splitkv_mla_combine_kernel (Kernel_traits::Element, kNThreads = 128) inside
// MLA_NUM_SPLITS_SWITCH, with CHECK_CUDA_KERNEL_LAUNCH() after each launch.
//
// run_flash_splitkv_fwd_mla_q_nope_pe(params, stream): same sequence of assertion,
// main launch and combine launch, but the main kernel is
// flash::flash_fwd_splitkv_mla_kernel_nope_pe and the combine kernel is instantiated
// with Kernel_traits::Element_O instead of Kernel_traits::Element.
// NOTE(review): confirm that the Element / Element_O difference between the launchers
// is intentional.
//
// run_mha_fwd_splitkv_mla(params, kv_cache_dtype, stream, is_q_nope_pe): entry point.
// Requires Headdim == 576 at compile time and asserts params.d_v == 512 and
// params.k_ptr == params.v_ptr (K and V share one buffer).
&flash::flash_fwd_splitkv_mla_kernel; constexpr size_t smem_size = 65536; // printf("smem_size = %d\n", smem_size); // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_splitkv_fwd_mla_q_nope_pe(Flash_fwd_mla_params &params, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel_nope_pe; constexpr size_t smem_size = 65536; // printf("smem_size = %d\n", smem_size); // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::Element_O, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } #endif template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv_cache_dtype, cudaStream_t stream, bool is_q_nope_pe) { static_assert(Headdim == 576); FLASH_ASSERT(params.d_v == 512); FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV if 
// (run_mha_fwd_splitkv_mla, continued) Selects kernel traits and launcher from
// is_q_nope_pe, kv_cache_dtype and params.seqlen_q:
//
//   is_q_nope_pe == true (returns after dispatch):
//     "auto"      Flash_fwd_kernel_traits_mla<576, 16, 64, 4, T, 512>
//                   -> run_flash_splitkv_fwd_mla_q_nope_pe, kAuto
//     "fp8_e5m2"  Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512, T>
//                   -> run_flash_splitkv_fwd_mla_q_nope_pe, kFp8E5M2
//     otherwise   prints a message and calls exit(-1)
//
//   is_q_nope_pe == false:
//     "auto"      seqlen_q > 32: Flash_fwd_kernel_traits_mla_tp1<576, 64, 64, 8, T, 512>
//                   -> run_flash_splitkv_fwd_mla_tp1, kAuto
//                 else: Flash_fwd_kernel_traits_mla<576, 16, 64, 4, T, 512>
//                   -> run_flash_splitkv_fwd_mla, kAuto
//     "fp8_e4m3"  Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512>
//                   -> run_flash_splitkv_fwd_mla, kFp8E4M3 (no seqlen_q > 32 variant)
//     "fp8_e5m2"  seqlen_q > 32: Flash_fwd_kernel_traits_mla_kvfp8_TP1<576, 64, 64, 8, T, 512>
//                   -> run_flash_splitkv_fwd_mla_tp1, kFp8E5M2
//                 else: Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512>
//                   -> run_flash_splitkv_fwd_mla, kFp8E5M2
//     otherwise   prints "Unsupported kv cache dtype" and calls exit(-1)
//
// NOTE(review): an unsupported dtype terminates the whole process via exit(-1) instead
// of reporting an error to the caller -- confirm this is acceptable for library use.
(is_q_nope_pe) { if (kv_cache_dtype == "auto") { using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 16, 64, 4, T, 512>; run_flash_splitkv_fwd_mla_q_nope_pe, Fp8KVCacheDataType::kAuto>(params, stream); } else if (kv_cache_dtype == "fp8_e5m2") { using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512, T>; run_flash_splitkv_fwd_mla_q_nope_pe, Fp8KVCacheDataType::kFp8E5M2>(params, stream); } else { printf("is_q_nope_pe = %d Unsupported kv cache dtype \n", is_q_nope_pe); exit(-1); } return; } if (kv_cache_dtype == "auto") { // printf(" seqlen_q %d \n", params.seqlen_q); if (params.seqlen_q > 32) { using Kernel_traits = Flash_fwd_kernel_traits_mla_tp1<576, 64, 64, 8, T, 512>; run_flash_splitkv_fwd_mla_tp1, Fp8KVCacheDataType::kAuto>(params, stream); } else { using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 16, 64, 4, T, 512>; run_flash_splitkv_fwd_mla, Fp8KVCacheDataType::kAuto>(params, stream); } } else if (kv_cache_dtype == "fp8_e4m3") { using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512>; run_flash_splitkv_fwd_mla, Fp8KVCacheDataType::kFp8E4M3>(params, stream); } else if (kv_cache_dtype == "fp8_e5m2") { if (params.seqlen_q > 32) { using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8_TP1<576, 64, 64, 8, T, 512>; run_flash_splitkv_fwd_mla_tp1, Fp8KVCacheDataType::kFp8E5M2>(params, stream); } else { using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512>; run_flash_splitkv_fwd_mla, Fp8KVCacheDataType::kFp8E5M2>(params, stream); } } else { printf("Unsupported kv cache dtype \n"); exit(-1); } }