// ============================================================================
// FP8 ("qkvfp8") MLA forward kernel: compile-time traits, shared-memory
// storage unions, and the output epilogue (store_float8).
//
// NOTE(review): the text below appears to have been damaged before it reached
// this file, so it cannot compile as shown. Restore it from version control
// rather than patching it by hand. The damage:
//   - newlines were collapsed, so each `//` comment now swallows the rest of
//     its physical line and `#include` shares a line with code;
//   - angle-bracket argument lists beginning with a letter were stripped
//     (e.g. `template struct ...`, `Layout>`, `cute::array_aligned> smem_v;`,
//     `std::conditional_t;`, `make_tensor(shape(...))`);
//   - the parameter `&params` appears as `¶ms`.
//
// Contents, in order:
//
// Flash_fwd_kernel_traits_mla_qkvfp8
//   Compile-time configuration for the kernel.
//   Exposes:
//     - Element / ElementO / ElementQ;
//     - ElementAccum (= float) and index_t (= int64_t);
//     - IS_WITH_CAT;
//     - kNWarps and kNThreads (= kNWarps * 64, i.e. 64 threads per warp);
//     - kBlockM, kBlockN, kHeadDim;
//     - kHeadDimV (kHeadDimV_ == 0 means "same as kHeadDim").
//   The static_asserts require both head dims to be multiples of 64 and
//   kHeadDimV <= kHeadDim.
//   Also defines:
//     - TiledMma (used for the Q*K gemm in the caller) and TiledMma_O (used
//       with the V^T smem copy);
//     - shared-memory layouts for K, K_temp, V, P, V^T (swizzled and
//       non-swizzled), Q and O;
//     - the smem copy atoms for O;
//     - the global-memory tiled copies for O, Q and Oaccum.
//   In this variant SmemLayoutK is tiled to 8*64 columns and SmemLayoutK_temp
//   to 7*64. The "bufferload" MMA atom is the active one. The "gloalload"
//   (presumably "global load") alternatives are kept as comments.
//
// Flash_fwd_kernel_traits_mla_qkvfp8_TP1 / _TP4
//   Same members as above, minus ElementQ and IS_WITH_CAT. Differences:
//     - SmemLayoutK and SmemLayoutK_temp are tiled to 10*64 columns;
//     - the "gloalload" MMA atom is the active one.
//   NOTE(review): in the visible (stripped) text the TP1 and TP4 bodies are
//   identical. If the lost template arguments are identical too, consider a
//   single shared traits template.
//
// flash::SharedStorageMLAFloat8 / _TP1 / _TP4
//   Each is a union of four groups:
//     {smem_v}
//     {smem_temp, smem_p, smem_row_sum, smem_row_max}
//     {smem_o}
//     {smem_q}
//   Because these are union members, the groups overlap in memory. Reusing
//   one group after another therefore needs a block-wide barrier between the
//   phases. The "Double buffer" comments on smem_v / smem_temp cannot be
//   checked here because the array sizes were stripped (TODO confirm).
//
// flash::store_float8(params, bidb, bidh, m_block, n_split_idx,
//                     shared_storage, tOrO, softmax, descale_k, scale_softmax)
//   Epilogue for one kBlockM-row tile of the output:
//    1. Reads this batch's split base (split_offset) from
//       params.num_splits_ptr[bidb] via __ldg.
//    2. Calls softmax.normalize_softmax_lse_fp8(tOrO, <smem_row_sum buffer>,
//       scale_softmax, descale_k) and keeps the returned per-row `lse`.
//       The helper's exact effect on tOrO is not visible here.
//    3. Converts tOrO with flash::convert_type and copies it into sOaccum
//       (smem_o) using the TiledMma_O partitioning. After __syncthreads() it
//       reads the tile back into registers with the gmem tiled copy.
//    4. Destination depends on Split:
//         - !Split: params.o_ptr / params.softmax_lse_ptr at the
//           (bidb, m_block, bidh) offset, row stride params.o_row_stride.
//         - Split: params.oaccum_ptr / params.softmax_lseaccum_ptr at split
//           index (split_offset + n_split_idx), row stride kHeadDimV.
//    5. Threads whose MMA column coordinate is 0 store `lse` for rows
//       < params.seqlen_q - m_block * kBlockM. flash::copy then writes O to
//       global memory with the same row bound.
//   NOTE(review): row_offset_oaccum is scaled by params.d_v while the Oaccum
//   tensor's row stride is kHeadDimV. This assumes params.d_v == kHeadDimV
//   (confirm).
//   NOTE(review): smem_o is in the same union as smem_v, which the caller's
//   main loop reads V/K from. The __syncthreads() before the register->smem
//   copy of O is commented out. This presumably relies on
//   normalize_softmax_lse_fp8 containing a block-wide barrier; verify, or
//   restore the barrier.
//   NOTE(review): tOpO is created but never filled in the visible code before
//   being passed to flash::copy. Confirm flash::copy does not read it on
//   this path.
//
// The final tokens of this chunk begin the template header of
// compute_attn_1rowblock_splitkv_mla_fp8_gfx938, which continues after it.
// ============================================================================
#include "flash_fwd_mla_kernel.h" template struct Flash_fwd_kernel_traits_mla_qkvfp8 { using Element = elem_type; using ElementO = elem_type_o; using ElementQ = elem_type_q; using ElementAccum = float; using index_t = int64_t; static constexpr bool IS_WITH_CAT = is_with_cat; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 64 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 64 == 0); static_assert(kHeadDimV <= kHeadDim); // static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = 3; //gloalload //using MMA_Atom_Arch = MMA_Atom; //bufferload using MMA_Atom_Arch = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>;// using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // using SmemLayoutRow = Layout>, Stride<_4, _1>>; // 128*4=512 using SmemLayoutRow = Layout, Stride<_1>>; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutAtomK = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<8 * 64>>{})); using SmemLayoutAtomK_temp = Layout, Int<64>>, Stride<_64, _1>>; using SmemLayoutK_temp = decltype(tile_to_shape( SmemLayoutAtomK_temp{}, Shape, Int<7*64>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, 
GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); //bufferload using SmemLayoutAtomQ = Layout, Int<64>>, Stride, _1>>; using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); // //gloalload // using SmemLayoutAtomQ = decltype(composition( // Swizzle{}, // Layout, Int<64>>, Stride, _1>>{}));//8*64 // using SmemLayoutQ = decltype(tile_to_shape( // SmemLayoutAtomQ{}, // Shape, Int>{})); using SmemLayoutAtomO = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using GmemLayoutAtomO = Layout, Stride< _16, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomO{}, Layout>{})); using GmemLayoutAtomQ = Layout, Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomOaccum = Layout, Stride< _16, _1>>; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); }; template struct Flash_fwd_kernel_traits_mla_qkvfp8_TP1 { using Element = elem_type; using ElementO = elem_type_o; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 64 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 64 == 0); static_assert(kHeadDimV <= kHeadDim); // static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; static constexpr int kSwizzle = 3; //gloalload using MMA_Atom_Arch = MMA_Atom; //bufferload // using MMA_Atom_Arch = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>;// using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // using SmemLayoutRow = Layout>, Stride<_4, _1>>; using SmemLayoutRow = Layout, Stride<_1>>; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutAtomK = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<10 * 64>>{})); using SmemLayoutAtomK_temp = Layout, Int<64>>, Stride<_64, _1>>; using SmemLayoutK_temp = decltype(tile_to_shape( SmemLayoutAtomK_temp{}, Shape, Int<10*64>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); //bufferload using SmemLayoutAtomQ = Layout, Int<64>>, Stride, _1>>; using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); // //gloalload // using SmemLayoutAtomQ = decltype(composition( // Swizzle{}, // Layout, Int<64>>, Stride, _1>>{}));//8*64 // using SmemLayoutQ = decltype(tile_to_shape( // SmemLayoutAtomQ{}, // Shape, Int>{})); using SmemLayoutAtomO = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using GmemLayoutAtomO = 
Layout, Stride< _16, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomO{}, Layout>{})); using GmemLayoutAtomQ = Layout, Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomOaccum = Layout, Stride< _16, _1>>; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); }; template struct Flash_fwd_kernel_traits_mla_qkvfp8_TP4 { using Element = elem_type; using ElementO = elem_type_o; using ElementAccum = float; using index_t = int64_t; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 64; static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 64 == 0); static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 64 == 0); static_assert(kHeadDimV <= kHeadDim); // static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; static constexpr int kSwizzle = 3; //gloalload using MMA_Atom_Arch = MMA_Atom; //bufferload // using MMA_Atom_Arch = MMA_Atom; using MMA_Atom_Arch_16x32 = MMA_Atom; using ValLayoutMNK = Layout>; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>;// using TiledMma_O = TiledMMA< MMA_Atom_Arch_16x32, Layout, _1>>, // 1x4x1 or 1x8x1 thread group ValLayoutMNK>; // using SmemLayoutRow = Layout>, Stride<_4, _1>>; using SmemLayoutRow = Layout, Stride<_1>>; // using SmemLayoutAtomK = Layout, Int<512>>, Stride<_512, _1>>; using SmemLayoutAtomK = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, Shape, Int<10 * 64>>{})); using SmemLayoutAtomK_temp = Layout, Int<64>>, Stride<_64, _1>>; using SmemLayoutK_temp = decltype(tile_to_shape( SmemLayoutAtomK_temp{}, Shape, Int<10*64>>{})); using SmemLayoutAtomV = SmemLayoutAtomK; using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, Shape, Int>{})); using SmemLayoutAtomP = Layout>, Stride>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape>{})); using SmemLayoutVtransposed = decltype( composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); //bufferload using SmemLayoutAtomQ = Layout, Int<64>>, Stride, _1>>; using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); // //gloalload // using SmemLayoutAtomQ = decltype(composition( // Swizzle{}, // Layout, Int<64>>, Stride, _1>>{}));//8*64 // using SmemLayoutQ = decltype(tile_to_shape( // SmemLayoutAtomQ{}, // Shape, Int>{})); using SmemLayoutAtomO = decltype(composition( Swizzle{}, Layout, Int<64>>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; using GmemLayoutAtomO = 
Layout, Stride< _16, _1>>; using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomO{}, Layout>{})); using GmemLayoutAtomQ = Layout, Stride< _8, _1>>; using GmemTiledCopyQ = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomQ{}, Layout>{})); using GmemLayoutAtomOaccum = Layout, Stride< _16, _1>>; using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); }; namespace flash { using namespace cute; template struct SharedStorageMLAFloat8 { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_o; }; struct { cute::array_aligned> smem_q; }; }; }; template struct SharedStorageMLAFloat8_TP1 { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_o; }; struct { cute::array_aligned> smem_q; }; }; }; template struct SharedStorageMLAFloat8_TP4 { union { struct { cute::array_aligned> smem_v; // Double buffer }; struct { cute::array_aligned> smem_temp; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_row_sum; cute::array_aligned> smem_row_max; }; struct { cute::array_aligned> smem_o; }; struct { cute::array_aligned> smem_q; }; }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, SharedStorage &shared_storage, AccO tOrO, Softmax softmax,float descale_k, float scale_softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; 
using Element = typename Kernel_traits::ElementO; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); // Epilogue const int split_offset = __ldg(params.num_splits_ptr + bidb); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); Tensor lse = softmax.template normalize_softmax_lse_fp8(tOrO, sRow_sum_reduce_buffer, scale_softmax, descale_k); using ElementO = std::conditional_t; Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning using SmemTiledCopyO = std::conditional_t< !Split, typename Kernel_traits::SmemCopyAtomO, typename Kernel_traits::SmemCopyAtomOaccum >; auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor rO = flash::convert_type(tOrO); Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) // __syncthreads(); cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gOaccum = 
make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); using GmemTiledCopyO = std::conditional_t; GmemTiledCopyO gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); __syncthreads(); // if (tidx >= kNThreadsS) { return; } Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) // Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); // CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } // Construct identity layout for sO Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM ); } template __forceinline__ __device__ void 
compute_attn_1rowblock_splitkv_mla_fp8_gfx938(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; constexpr static int IS_WITH_CAT = Kernel_traits::IS_WITH_CAT; using Element = typename Kernel_traits::Element; using ElementQ = typename Kernel_traits::ElementQ; using index_t = typename Kernel_traits::index_t; extern __shared__ char shared_memory[]; const int tidx = threadIdx.x; const int lane_idx = tidx % 64; const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*576 Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*512 Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); //16,576 Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});//64,512 Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); //64,512 Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); //16*64 Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});//512,64 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor 
sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); //64 typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64 typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32 auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); union Fp8_storage { intx4_t data; intx2_t p[2]; int bf16[4]; }; #if 0 #else if constexpr (!IS_WITH_CAT) { const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{}));//16*576 lds_direct_copy_qkvfp8(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); asm 
volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); __syncthreads(); } else { const index_t row_offset_q_nope = bidb * params.q_nope_batch_stride + m_block * kBlockM * params.q_nope_row_stride + bidh * params.q_nope_head_stride; Tensor gQ_nope = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_nope_ptr) + row_offset_q_nope), Shape, Int<512>>{}, make_stride(params.q_nope_row_stride, _1{})); const index_t row_offset_q_pe = bidb * params.q_pe_batch_stride + m_block * kBlockM * params.q_pe_row_stride + bidh * params.q_pe_head_stride; Tensor gQ_pe = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_pe_ptr) + row_offset_q_pe), Shape, Int<64>>{}, make_stride(params.q_pe_row_stride, _1{})); if constexpr (std::is_same_v) { ElementQ* s_q = reinterpret_cast(shared_memory); auto lds_direct_copy_q = [&](const int k_idx, const int offset_k) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; if (k_idx == 2) { *(uint64_t*)&glob_ptr = reinterpret_cast(gQ_pe.data().get()); } else { *(uint64_t*)&glob_ptr = reinterpret_cast(gQ_nope.data().get()); } uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; constexpr int bytes_per_warp = 64 * 8 * 2; constexpr int bytes_per_block = bytes_per_warp * 4; const int row_idx = lane_idx % 16; const int col_idx = lane_idx / 16; const int row_offset = row_idx; int col_offset; int offset_v; if (k_idx == 2) { col_offset = warp_idx * 8 + col_idx * 16; offset_v = (row_offset * params.q_pe_row_stride + col_offset) * 2; } else { col_offset = k_idx * 256 + warp_idx * 64 + col_idx * 16 + offset_k * 8; offset_v = (row_offset * params.q_nope_row_stride + col_offset) * 2; } if (k_idx == 2 && warp_idx >= 2) { offset_v = -1; } if (m_block * kBlockM + row_idx >= params.seqlen_q) { offset_v = -1; } const int offset_s = 
0; int ldsAddrPerWave = reinterpret_cast(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block + offset_k * 3 * bytes_per_block; asm volatile( "s_mov_b32 m0, %1 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); }; lds_direct_copy_q(0, 0); lds_direct_copy_q(1, 0); lds_direct_copy_q(0, 1); lds_direct_copy_q(1, 1); lds_direct_copy_q(2, 0); ElementQ* s_q_read_ptr = s_q + lane_idx * 8; Fp8_storage bf16_data; asm volatile("s_waitcnt vmcnt(4) \n s_barrier"); float fp32[8]; union Fp8_temp{ int32_t data; uint8_t p_fp8[4]; }; for (int k = 0; k < 4; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 0; i < 8; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } asm volatile("s_waitcnt vmcnt(3) \n s_barrier"); for (int k = 4; k < 8; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 0; i < 8; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = 
fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } asm volatile("s_waitcnt vmcnt(2) \n s_barrier"); s_q_read_ptr = s_q + lane_idx * 8 + 3 * 4 * 16 * 4 * 8; for (int k = 0; k < 4; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 8; i < 16; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } asm volatile("s_waitcnt vmcnt(1) \n s_barrier"); for (int k = 4; k < 8; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 8; i < 16; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } asm volatile("s_waitcnt vmcnt(0) \n s_barrier"); s_q_read_ptr = s_q + lane_idx * 8 + 2 * 4 * 16 * 4 * 8; for (int k = 8; k < 9; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 
2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 0; i < 8; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } for (int k = 8; k < 9; k++) { bf16_data.data = *reinterpret_cast(s_q_read_ptr); for (int i = 0; i < 4; i++) { fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false); fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true); } for (int i = 8; i < 16; i+=4) { Fp8_temp fp8_tmp; fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false); fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true); tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0]; tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1]; tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2]; tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3]; } s_q_read_ptr += 16 * 32; } __syncthreads(); } else { lds_direct_copy_qkvfp8(gQ_nope, sQ, 0, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8(gQ_nope, sQ, 1, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_pe(gQ_pe, sQ, 2, params.q_pe_head_stride, params.seqlen_q - m_block * kBlockM); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, 
tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); __syncthreads(); } } // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); #endif auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); //将sk中数据按照tiled_mma拷贝到tSrQ // Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); //128,64 gk->rk auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; constexpr static int k1_loops = size<2>(tOrVt); // Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // clear(acc_o); flash::Softmax<1> softmax; constexpr static int STAGE = 8; #if 1 v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1; c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f; c0_1.x = 0.0f; c0_1.y = 0.0f; c0_1.z = 0.0f; c0_1.w = 0.0f; c1_0.x = 0.0f; c1_0.y = 0.0f; c1_0.z = 0.0f; c1_0.w = 0.0f; c1_1.x = 0.0f; c1_1.y = 0.0f; c1_1.z = 0.0f; c1_1.w = 0.0f; c2_0.x = 0.0f; c2_0.y = 0.0f; c2_0.z = 0.0f; c2_0.w = 0.0f; c2_1.x = 0.0f; c2_1.y = 0.0f; c2_1.z = 0.0f; c2_1.w = 0.0f; c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f; c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f; // #pragma unroll for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); clear(acc_s); // asm volatile("s_barrier\n\t"); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #if 1 lds_direct_copy_qkvfp8(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN); 
lds_direct_copy_qkvfp8(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN); constexpr static int BUFFER_SIZE = 1; uint128_t buffer[BUFFER_SIZE]; buffer_load_copy_qkvfp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s); asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s); asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2)); cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s); asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3)); cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4)); cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5)); cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6)); cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7)); Fp8_storage v3_0, v3_1; __ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data); __ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data); 
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s); asm volatile("s_waitcnt vmcnt(0) \n\t"); buffer_to_tensor(buffer[0], tSrK, 8); cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); #else #endif gK.data() = gK.data() + (-offset_k); Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } // We have key_padding_mask so we'll need to Check_inf // if constexpr (n_masking_steps == 1) // { // softmax.template softmax_rescale_o_fp8(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1); // } // else { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? 
softmax.template softmax_rescale_o_fp8(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1) : softmax.template softmax_rescale_o_fp8(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1); } // Tensor rP = flash::convert_type(acc_s); Fp8_storage data_fp8; // convert_layout_acc_Aregs_fp8(tiled_mma, tiled_mma_o, rP, sP, data_fp8.data); { int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; int32_t result; result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0), acc_s(1), result, false); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2), acc_s(3), result, true); int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0])); *lds_ptr = result; __syncthreads(); data_fp8.data = *reinterpret_cast(&(sP[tid * 16])); } { Fp8_storage v0_0, v0_1; Fp8_storage v1_0, v1_1; Fp8_storage v2_0, v2_1; __builtin_amdgcn_sched_barrier(0); __ds_read_m32x32_row_col_rrow<0, 0, 0>(tOsVt, v0_0.data); __ds_read_m32x32_row_col_rrow<1, 0, 1>(tOsVt, v1_0.data); __ds_read_m32x32_row_col_rrow<2, 0, 2>(tOsVt, v2_0.data); c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[0], c3_0, true, false); c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[1], c3_1, true, false); c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[0], c3_0, true, false); c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[1], c3_1, true, false); c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[0], c0_0, true, false); c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[1], c0_1, true, false); c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[0], c1_0, true, false); c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[1], c1_1, true, false); 
c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[0], c2_0, true, false); c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[1], c2_1, true, false); __ds_read_m32x32_row_col_rrow<0, 1, 0>(tOsVt, v0_1.data); __ds_read_m32x32_row_col_rrow<1, 1, 1>(tOsVt, v1_1.data); __ds_read_m32x32_row_col_rrow<2, 1, 2>(tOsVt, v2_1.data); c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[0], c0_0, true, false); c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[1], c0_1, true, false); c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[0], c1_0, true, false); c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[1], c1_1, true, false); c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[0], c2_0, true, false); c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[1], c2_1, true, false); __builtin_amdgcn_sched_barrier(0); } } #endif Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w; acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w; acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w; acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w; acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w; acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w; acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w; acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w; if (NoSplit) store_float8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, 
descale_k, scale_softmax); else store_float8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax); } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP1(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) { if (n_block_max <= n_block_min) { return; } constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64); const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{}));//16*576 const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*576 const auto gK_data = gK.data(); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*512 Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); //16,576 Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});//64,512 Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename 
Kernel_traits::SmemLayoutK{}); //64,512 Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); //16*64 Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});//512,64 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); //64 Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64 typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32 union Fp8_storage { // uint32x4_t val; intx4_t data; intx2_t p[2]; int32_t fp8_array[4]; }; Fp8_storage q_r[9]; #if 1 auto gQ_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.q_row_stride; const int q_zero_pad = std::min(std::max(m_block * kBlockM + ((warp_id) % 4 + 1) * 16 - params.seqlen_q, 0), 16); uint32x4_t gQ_rscr = make_rscr((unsigned char*)(gQ.data().get() + gQ_offset), params.q_row_stride, q_zero_pad); auto q_lds_addr = reinterpret_cast(sQ.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64) | 0x80000000; if (m_block * kBlockM + ((warp_id) % 4) * 16 < params.seqlen_q) { __builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 0, 1, 1, 0, 0); q_lds_addr += 64*128; __builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 128, 1, 1, 0, 0); q_lds_addr += 64*128; __builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 256, 1, 1, 0, 0); q_lds_addr += 64*128; __builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 256+128, 1, 1, 0, 0); 
q_lds_addr += 64*128; if (warp_id < 4) { __builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 512, 1, 1, 0, 0); } else { lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 4); } } else { lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 0); lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 1); lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 2); lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 3); lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 4); } auto q_lds_read_ptr = sQ.data().get() + (warp_id % 4) * 16 * 64; asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); q_r[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 0, 3, 1, 0); // q_lds_read_ptr += 64 * 64; q_r[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 64*64, 3, 1, 0); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); // q_lds_read_ptr += 64 * 64; q_r[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 2*64*64, 3, 1, 0); // q_lds_read_ptr += 64 * 64; q_r[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 3*64*64, 3, 1, 0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // q_lds_read_ptr += 64 * 64; q_r[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 4*64*64, 3, 1, 0); // q_lds_read_ptr += 64 * 64; q_r[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 5*64*64, 3, 1, 0); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); // q_lds_read_ptr += 64 * 64; q_r[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 6*64*64, 3, 1, 0); // q_lds_read_ptr += 64 * 64; q_r[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 
7*64*64, 3, 1, 0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // q_lds_read_ptr += 64 * 64; q_r[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 8*64*64, 3, 1, 0); __syncthreads(); #endif auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); //将sk中数据按照tiled_mma拷贝到tSrQ // Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); //128,64 gk->rk auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; union Val { intx2_t val_to_mmac; int32_t data[2]; }; // Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // clear(acc_o); flash::Softmax<1> softmax; v4f acco_f32[16]; for (int i = 0; i < 16; i++) { acco_f32[i].x = 0.0f; acco_f32[i].y = 0.0f; acco_f32[i].z = 0.0f; acco_f32[i].w = 0.0f; } constexpr static int STAGE = 8; extern __shared__ char shared_memory[]; struct IsMaskBlock {}; struct IsFirstMaskBlock {}; struct IsNoMaskBlock {}; struct IsLastBlock {}; int lane_id = tidx % 64; int row = lane_id / 4; int col = lane_id % 4; col = (col + (row / 2) % 4) % 4; const auto lds_offset = row * 64 + col * 16 + (warp_id / 4) * 64 * 64; uint8_t* kv_lds_write_ptr_base = reinterpret_cast(shared_memory) + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64 + row * 64 + col * 16; Fp8_storage kv_data[5]; { int cur_block_table; // const int *cur_block_table_ptr; cur_block_table = block_table[n_block]; index_t offset_k; //gK.data() = gK_data + 
(offset_k); // cur_block_table_ptr = block_table + n_block; // asm volatile("s_load_dword %1, %0, 0x0\n\t" // "s_waitcnt lgkmcnt(0)\n\t": // "+s"(cur_block_table_ptr), // "=s"(cur_block_table)); offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); // buffer_load_copy_fp8_tp1(gK, kv_data[0].data, params.k_row_stride, seqlen_k - n_block * kBlockN); // buffer_load_copy_fp8_tp1(gK, kv_data[1].data, params.k_row_stride, seqlen_k - n_block * kBlockN); // buffer_load_copy_fp8_tp1(gK, kv_data[2].data, params.k_row_stride, seqlen_k - n_block * kBlockN); // buffer_load_copy_fp8_tp1(gK, kv_data[3].data, params.k_row_stride, seqlen_k - n_block * kBlockN); // buffer_load_copy_fp8_tp1(gK, kv_data[4].data, params.k_row_stride, seqlen_k - n_block * kBlockN); // uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base; // // for (int i = 0; i < ) // *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[0].data; // kv_lds_write_ptr += 64 * 128; // *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[1].data; // kv_lds_write_ptr += 64 * 128; // *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[2].data; // kv_lds_write_ptr += 64 * 128; // *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[3].data; // kv_lds_write_ptr += 64 * 128; // *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[4].data; // kv_lds_write_ptr += 64 * 128; // gK.data() = gK.data() + (offset_k); auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride; // auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride; // const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16); const int k_zero_pad = std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0); uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad); auto k_lds_addr = reinterpret_cast(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64); if (n_block * 
kBlockN + ((warp_id) % 4) * 16 < seqlen_k) { k_lds_addr |= 0x80000000; __builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0); k_lds_addr += 64 * 128; __builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 128, 1, 1, 0, 0); k_lds_addr += 64 * 128; __builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256, 1, 1, 0, 0); k_lds_addr += 64 * 128; __builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256+128, 1, 1, 0, 0); k_lds_addr += 64 * 128; if (warp_id < 4) { __builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 512, 1, 1, 0, 0); } else { lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4); } } else { lds_direct_copy_qkvfp8_zero_lds(gK, sK, 0); lds_direct_copy_qkvfp8_zero_lds(gK, sK, 1); lds_direct_copy_qkvfp8_zero_lds(gK, sK, 2); lds_direct_copy_qkvfp8_zero_lds(gK, sK, 3); lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4); } } auto process_one_block = [&] (int block_idx, auto is_mask_block_t) { static constexpr bool IS_MASK_BLOCK = std::is_same_v; static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v; static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v; static constexpr bool IS_LAST_BLOCK = std::is_same_v; v4f accs_f32[2]; for (int i = 0; i < 2; i++) { accs_f32[i].x = 0.0f; accs_f32[i].y = 0.0f; accs_f32[i].z = 0.0f; accs_f32[i].w = 0.0f; } __syncthreads(); auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64; constexpr static int k_read_lds_offset = 32 * 64; // Fp8_storage data[9]; #if 0 Fp8_storage k_data[9]; __builtin_amdgcn_sched_barrier(0); k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096, 3, 1, 0); k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096, 3, 1, 0); 
k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096, 3, 1, 0); k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096, 3, 1, 0); k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 4 * 4096, 3, 1, 0); k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096, 3, 1, 0); k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096, 3, 1, 0); k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096, 3, 1, 0); k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096, 3, 1, 0); #pragma unroll for (int i = 0; i < 9; i++) { accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[0], true, false); } k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) 
int*)(k_lds_read_ptr), 4 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096 + k_read_lds_offset, 3, 1, 0); k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096 + k_read_lds_offset, 3, 1, 0); #pragma unroll for (int i = 0; i < 9; i++) { accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[1], true, false); } __builtin_amdgcn_sched_barrier(0); #else { constexpr static int k_idx = 0; // k_lds_read_ptr += k_idx * 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 1; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; 
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 2; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 3; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, 
false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 4; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 5; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 
+ k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 6; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } { constexpr static int k_idx = 7; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], 
true, false); } { constexpr static int k_idx = 8; // k_lds_read_ptr += 64 * 64; Fp8_storage k_data; k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false); accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false); k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false); } #endif Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w; acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w; if constexpr (!IS_NO_MASK_BLOCK) { Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } } softmax.template softmax_rescale_o_fp8_tp1(acc_s, 
sRow_max_reduce_buffer, scale_softmax_log2, acco_f32); Fp8_storage p_fp8; { __builtin_amdgcn_sched_barrier(0); int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; int32_t result; result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true); // int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64])); // *lds_ptr = result; int32_t result1; result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false); result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true); // lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8])); // *lds_ptr = result1; int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64])); *lds_ptr = result; int32_t* lds_ptr1 = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64 + 8])); *lds_ptr1 = result1; __syncthreads(); p_fp8.data = *reinterpret_cast(&(sP[tid * 16 + (warp_id % 4) * 16 * 64])); __builtin_amdgcn_sched_barrier(0); } if (block_idx > n_block_min) { int cur_block_table; const int *cur_block_table_ptr; cur_block_table = block_table[block_idx - 1]; index_t offset_k; // cur_block_table_ptr = block_table + block_idx; // asm volatile("s_load_dword %1, %0, 0x0\n\t" // "s_waitcnt lgkmcnt(0)\n\t": // "+s"(cur_block_table_ptr), // "=s"(cur_block_table)); offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK_data + (offset_k); buffer_load_copy_fp8_tp1(gK, kv_data[0].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1(gK, kv_data[1].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); 
buffer_load_copy_fp8_tp1(gK, kv_data[2].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1(gK, kv_data[3].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1(gK, kv_data[4].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); } for (int n = 0; n < 4; n++) { Fp8_storage v0_0, v0_1; v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128)); v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128)); for (int j = 0; j < 4; j++) { intx2_t v; v[0] = v0_0.fp8_array[j]; v[1] = v0_1.fp8_array[j]; acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[0], v, acco_f32[n * 4 + j], true, false); } v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128 + 32 * 64)); v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128 + 32 * 64)); for (int j = 0; j < 4; j++) { intx2_t v; v[0] = v0_0.fp8_array[j]; v[1] = v0_1.fp8_array[j]; acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[1], v, acco_f32[n * 4 + j], true, false); } } if (block_idx > n_block_min) { __syncthreads(); uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base; // for (int i = 0; i < ) *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[0].data; kv_lds_write_ptr += 64 * 128; *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[1].data; kv_lds_write_ptr += 64 * 128; *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[2].data; kv_lds_write_ptr += 64 * 128; *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[3].data; kv_lds_write_ptr += 64 * 128; *(reinterpret_cast(kv_lds_write_ptr)) = kv_data[4].data; } // asm volatile("s_barrier \n\t"); }; #if 0 #endif #if 1 if constexpr (n_masking_steps == 1) 
{ process_one_block(n_block, IsFirstMaskBlock{}); n_block--; } else { int masking_step = 1; process_one_block(n_block, IsFirstMaskBlock{}); n_block--; for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) { process_one_block(n_block, IsMaskBlock{}); } } for(; n_block >= n_block_min; --n_block) { process_one_block(n_block, IsNoMaskBlock{}); } #endif using ElementO = typename Kernel_traits::ElementO; using ElementAccum = typename Kernel_traits::ElementAccum; if (NoSplit) { const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; constexpr bool Split = false; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { using result_type = cutlass::Array; int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16 + (warpid % 4) * 16; if (row < params.seqlen_q - m_block * kBlockM) { for (int n = 0; n < 4; n++) { col = (tidx % 64 / 16) * 16 + n * 128 + (warpid / 4) * 64; { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].x, 0, acco_f32[n * 4 + 1].x, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].x, 0, acco_f32[n * 4 + 3].x, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].y, 0, acco_f32[n * 4 + 1].y, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].y, 0, acco_f32[n * 4 + 3].y, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].z, 0, acco_f32[n * 4 + 1].z, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].z, 0, acco_f32[n * 4 + 3].z, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = 
__builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].w, 0, acco_f32[n * 4 + 1].w, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].w, 0, acco_f32[n * 4 + 3].w, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; // col += 16; } } } } } } else { constexpr bool Split = true; const int split_offset = __ldg(params.num_splits_ptr + bidb); const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + (row_offset_lseaccum)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16 + (warpid % 4) * 16; if (row < params.seqlen_q - m_block * kBlockM) { for (int n = 0; n < 4; n++) { col = (tidx % 64 / 16) * 16 + n * 128 + (warp_id / 4) * 64; { gOaccum(row, col) = acco_f32[n * 4 + 0].x; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].x; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].x; gOaccum(row, 
col + 3) = acco_f32[n * 4 + 3].x; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].y; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].y; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].y; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].y; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].z; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].z; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].z; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].z; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].w; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].w; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].w; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].w; } } } } } } template __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; using Element = typename Kernel_traits::Element; using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; const int warp_id = tidx / 64; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{}));//16*576 const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*576 Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, 
Int>{}, make_stride(params.k_row_stride, _1{}));//64*512 Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); //16,576 Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});//64,512 Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); //64,512 Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); //16*64 Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});//512,64 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); //64 Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64 typename Kernel_traits::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32 union Fp8_storage { intx4_t data; intx2_t p[2]; int32_t fp8_array[4]; }; #if 0 auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM); __syncthreads(); #else Tensor tSrQ = thr_mma.partition_fragment_A(gQ); lds_direct_copy_qkvfp8_q_tp4(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4(gQ, sQ, 
1, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); uint8_t* q_lds_read_ptr = reinterpret_cast(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 2) * (16 * 64); { int k = 0; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k).storage = q_lds_read_ptr[i]; } q_lds_read_ptr += 32*64; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i]; } // int k = 0; // intx4_t * q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // intx4_t * tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); // q_lds_read_ptr += 64*64; // q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k + 1))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); } asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); { q_lds_read_ptr += 32*64; int k = 2; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k).storage = q_lds_read_ptr[i]; } q_lds_read_ptr += 32*64; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i]; } // q_lds_read_ptr += 64*64; // int k = 2; // intx4_t * q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // intx4_t * tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); // q_lds_read_ptr += 64*64; // q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k + 1))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); } asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); { q_lds_read_ptr += 32*64; int k = 4; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k).storage = q_lds_read_ptr[i]; } q_lds_read_ptr += 32*64; for (int i = 0; i < 16; i++) { tSrQ(i, 0, 
k+1).storage = q_lds_read_ptr[i]; } // q_lds_read_ptr += 64*64; // int k = 4; // intx4_t * q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // intx4_t * tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); // q_lds_read_ptr += 64*64; // q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k + 1))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); } asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); { q_lds_read_ptr += 32*64; int k = 6; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k).storage = q_lds_read_ptr[i]; } q_lds_read_ptr += 32*64; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i]; } // q_lds_read_ptr += 64*64; // int k = 6; // intx4_t * q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // intx4_t * tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); // q_lds_read_ptr += 64*64; // q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k + 1))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); } asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); { q_lds_read_ptr += 32*64; int k = 8; for (int i = 0; i < 16; i++) { tSrQ(i, 0, k).storage = q_lds_read_ptr[i]; } // q_lds_read_ptr += 64*64; // int k = 8; // intx4_t * q_lds_read_16 = reinterpret_cast(q_lds_read_ptr); // intx4_t * tSrQ_ptr = reinterpret_cast(&(tSrQ(0, 0, k))); // *tSrQ_ptr = *reinterpret_cast(q_lds_read_ptr); } __syncthreads(); #endif auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); //将sk中数据按照tiled_mma拷贝到tSrQ // Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); //128,64 gk->rk auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = 
smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; union Val { intx2_t val_to_mmac; int32_t data[2]; }; // Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // clear(acc_o); flash::Softmax<1> softmax; v4f acco_f32[16]; for (int i = 0; i < 16; i++) { acco_f32[i].x = 0.0f; acco_f32[i].y = 0.0f; acco_f32[i].z = 0.0f; acco_f32[i].w = 0.0f; } constexpr static int STAGE = 8; for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); clear(acc_s); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); // asm volatile("s_barrier \n\t"); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); lds_direct_copy_qkvfp8(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8(gK, 
sK, 8, params.k_row_stride, seqlen_k - n_block * kBlockN); asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s); asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s); asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2)); cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s); asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3)); cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4)); cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5)); cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6)); cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7)); cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8)); cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s); gK.data() = gK.data() + (-offset_k); // asm volatile("s_barrier \n\t"); Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!Is_causal) { if (int(get<1>(tScS(i))) >= 
int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } // asm volatile("s_barrier \n\t"); { const bool is_first_masking_step = masking_step == 0; // is_first_masking_step // ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2) // : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2); is_first_masking_step ? softmax.template softmax_rescale_o_fp8_tp4(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32) : softmax.template softmax_rescale_o_fp8_tp4(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32); } // asm volatile("s_barrier \n\t"); // if (block0() && tidx < 64) // { // // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w, // // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w, // // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w, // // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w // // ); // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3), // acc_s(4), acc_s(5), acc_s(6), acc_s(7) // // acc_s(8), acc_s(9), acc_s(10), acc_s(11), // // acc_s(12), acc_s(13), acc_s(14), acc_s(15) // ); // } #if 1 Fp8_storage p_fp8; { __builtin_amdgcn_sched_barrier(0); int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; int32_t result; result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), 
acc_s(1, 0, 0), result, false); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true); // int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64])); // *lds_ptr = result; int32_t result1; result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false); result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true); // lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8])); // *lds_ptr = result1; int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 2) * 4 + (warp_id % 2) * 16 * 64])); *lds_ptr = result; int32_t* lds_ptr1 = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 2) * 4 + (warp_id % 2) * 16 * 64 + 8])); *lds_ptr1 = result1; __syncthreads(); p_fp8.data = *reinterpret_cast(&(sP[tid * 16 + (warp_id % 2) * 16 * 64])); __builtin_amdgcn_sched_barrier(0); } { __builtin_amdgcn_sched_barrier(0); for (int i = 0; i < 4; i++) { { int k = 0; Fp8_storage v0_0, v0_1; v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k)))); v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1)))); // if (block0() && tidx < 64) // { // float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0); // float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1); // float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2); // float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3); // printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]); // } for (int j = 0; j < 4; j++) { Val tmp; tmp.data[0] = v0_0.fp8_array[j]; tmp.data[1] = v0_1.fp8_array[j]; acco_f32[i * 4 + j] = 
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false); } } { int k = 2; Fp8_storage v0_0, v0_1; v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k)))); v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1)))); // if (block0() && tidx < 64) // { // float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0); // float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1); // float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2); // float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3); // printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]); // } for (int j = 0; j < 4; j++) { Val tmp; tmp.data[0] = v0_0.fp8_array[j]; tmp.data[1] = v0_1.fp8_array[j]; acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false); } } } __builtin_amdgcn_sched_barrier(0); } // if (block0()) // { // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w, // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w, // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w, // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w // ); // // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3), // // acc_s(4), acc_s(5), acc_s(6), acc_s(7), // // acc_s(8), acc_s(9), acc_s(10), acc_s(11), // // acc_s(12), acc_s(13), acc_s(14), acc_s(15) // // ); // } // asm volatile("s_barrier \n\t"); #endif } using ElementO = typename Kernel_traits::ElementO; using ElementAccum = typename Kernel_traits::ElementAccum; const int split_offset = 
__ldg(params.num_splits_ptr + bidb); // Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{}); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; if (NoSplit) { constexpr bool Split = false; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_fp8_tp4(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? 
row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } // if (tidx == 1) // { // printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w); // } { using result_type = cutlass::Array; int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16 + (warpid % 2) * 16; if (row < params.seqlen_q - m_block * kBlockM) { for (int n = 0; n < 4; n++) { col = (tidx % 64 / 16) * 16 + n * 128 + (warpid / 2) * 64; { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].x, 0, acco_f32[n * 4 + 1].x, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].x, 0, acco_f32[n * 4 + 3].x, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].y, 0, acco_f32[n * 4 + 1].y, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].y, 0, acco_f32[n * 4 + 3].y, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].z, 0, acco_f32[n * 4 + 1].z, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].z, 0, acco_f32[n * 4 + 3].z, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = 
res1[0]; gOaccum(row, col + 3) = res1[1]; col += 4; } { auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].w, 0, acco_f32[n * 4 + 1].w, 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].w, 0, acco_f32[n * 4 + 3].w, 0); auto res0 = reinterpret_cast(d0); auto res1 = reinterpret_cast(d1); gOaccum(row, col) = res0[0]; gOaccum(row, col + 1) = res0[1]; gOaccum(row, col + 2) = res1[0]; gOaccum(row, col + 3) = res1[1]; // col += 16; } // for (int j = 0; j < 4; j++) // { // auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0); // auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0); // auto res0 = reinterpret_cast(d0); // auto res1 = reinterpret_cast(d1); // gOaccum(row, col) = res0[0]; // gOaccum(row, col + 1) = res0[1]; // gOaccum(row, col + 2) = res1[0]; // gOaccum(row, col + 3) = res1[1]; // col += 16; // } } // for (int n = 0; n < 8; n++) // { // using result_type = cutlass::Array; // auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0); // auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0); // col = (tidx % 64 / 16) * 4 + n * 64; // auto res0 = reinterpret_cast(d0); // auto res1 = reinterpret_cast(d1); // gOaccum(row, col) = res0[0]; // gOaccum(row, col + 1) = res0[1]; // gOaccum(row, col + 2) = res1[0]; // gOaccum(row, col + 3) = res1[1]; // } } } } } else { constexpr bool Split = true; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_fp8_tp4(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? 
row_offset_lseaccum : row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16 + (warpid % 2) * 16; if (row < params.seqlen_q - m_block * kBlockM) { // for (int n = 0; n < 32; n++) // { // col = (tidx % 64 / 16) * 4 + n * 16; // gOaccum(row, col) = acco_f32[n].x; // gOaccum(row, col + 1) = acco_f32[n].y; // gOaccum(row, col + 2) = acco_f32[n].z; // gOaccum(row, col + 3) = acco_f32[n].w; // } for (int n = 0; n < 4; n++) { col = (tidx % 64 / 16) * 16 + n * 128 + (warp_id / 2) * 64; { gOaccum(row, col) = acco_f32[n * 4 + 0].x; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].x; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].x; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].x; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].y; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].y; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].y; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].y; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].z; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].z; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].z; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].z; col += 4; } { gOaccum(row, col) = acco_f32[n * 4 + 0].w; gOaccum(row, col + 1) = acco_f32[n * 4 + 1].w; gOaccum(row, col + 2) = acco_f32[n * 4 + 2].w; gOaccum(row, col + 3) = acco_f32[n * 4 + 3].w; } // for (int j = 0; j < 4; j++) { // gOaccum(row, col) = acco_f32[n * 4 + j].x; // gOaccum(row, col + 1) = acco_f32[n * 4 + j].y; // gOaccum(row, col + 2) = acco_f32[n * 4 + j].z; // gOaccum(row, col + 3) = acco_f32[n * 4 + j].w; // col += 16; // } } } } } // Tensor acc_o = 
partition_fragment_C(tiled_mma_o, Shape, Int>{}); // for (int n = 0; n < 8; n++) // { // acc_o(0, 0, n) = acco_f32[n * 2].x; // acc_o(1, 0, n) = acco_f32[n * 2].y; // acc_o(2, 0, n) = acco_f32[n * 2].z; // acc_o(3, 0, n) = acco_f32[n * 2].w; // acc_o(4, 0, n) = acco_f32[n * 2 + 1].x; // acc_o(5, 0, n) = acco_f32[n * 2 + 1].y; // acc_o(6, 0, n) = acco_f32[n * 2 + 1].z; // acc_o(7, 0, n) = acco_f32[n * 2 + 1].w; // } // if (NoSplit) // store_float8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax); // else // store_float8(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax); } template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1) flash_fwd_splitkv_mla_kernel_fp8(const Flash_fwd_mla_params params) { constexpr int kBlockN = Kernel_traits::kBlockN; const int m_block = blockIdx.x; const int bidh = blockIdx.y; const int partition_idx = blockIdx.z; extern __shared__ char shared_memory[]; auto &shared_storage = *reinterpret_cast(shared_memory); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); float descale_k = 1.f; float scale_softmax = params.scale_softmax; float scale_softmax_log2 = params.scale_softmax_log2; float descale_q = __ldg(params.descale_q_ptr); descale_k = __ldg(params.descale_k_ptr); scale_softmax = scale_softmax * descale_q * descale_k; scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { const int n_split_idx = batch_id == begin_idx ? 
begin_n_split_idx : 0; const int seqlen_k = *(params.cu_seqlens_k + batch_id); const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. } #if defined(__gfx938__) { flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); } #endif } } template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1) flash_fwd_splitkv_mla_kernel_fp8_tp1(const Flash_fwd_mla_params params) { constexpr int kBlockN = Kernel_traits::kBlockN; const int m_block = blockIdx.x; const int bidh = blockIdx.y; const int partition_idx = blockIdx.z; extern __shared__ char shared_memory[]; auto &shared_storage = *reinterpret_cast(shared_memory); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); float descale_k = 1.f; float scale_softmax = params.scale_softmax; float scale_softmax_log2 = params.scale_softmax_log2; float descale_q = __ldg(params.descale_q_ptr); descale_k = __ldg(params.descale_k_ptr); scale_softmax = scale_softmax * descale_q * descale_k; scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { const int n_split_idx = batch_id == 
begin_idx ? begin_n_split_idx : 0; const int seqlen_k = *(params.cu_seqlens_k + batch_id); const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. } #if defined(__gfx938__) { flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP1(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); } #endif } } template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1) flash_fwd_splitkv_mla_kernel_fp8_tp4(const Flash_fwd_mla_params params) { constexpr int kBlockN = Kernel_traits::kBlockN; const int m_block = blockIdx.x; const int bidh = blockIdx.y; const int partition_idx = blockIdx.z; extern __shared__ char shared_memory[]; auto &shared_storage = *reinterpret_cast(shared_memory); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); float descale_k = 1.f; float scale_softmax = params.scale_softmax; float scale_softmax_log2 = params.scale_softmax_log2; float descale_q = __ldg(params.descale_q_ptr); descale_k = __ldg(params.descale_k_ptr); scale_softmax = scale_softmax * descale_q * descale_k; scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { const int n_split_idx = 
batch_id == begin_idx ? begin_n_split_idx : 0; const int seqlen_k = *(params.cu_seqlens_k + batch_id); const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. } #if defined(__gfx938__) { flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); } #endif } } } // namespace flash //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_flash_splitkv_fwd_mla_fp8_tp1(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); // const static bool enable_asm = get_env_("FLASH_MLA_ENABLE_ASM"); if (1) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8_tp1; constexpr size_t smem_size = 65536; // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); } // else { // static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR"); // assert(FLASH_MLA_ASM_DIR != nullptr && "FLASH_MLA_ASM_DIR nullptr \n"); // constexpr size_t smem_size = 65536; // std::string co_file = std::string(FLASH_MLA_ASM_DIR) + // "flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co"; // hipError_t status = hipSuccess; // static hipModule_t fwd_module_sample; // static bool IS_FWD_MODULE_LOADED = false; // if (IS_FWD_MODULE_LOADED == false) // { // status = hipModuleLoad(&fwd_module_sample, co_file.c_str()); // if (status not_eq 
hipSuccess) { // printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str()); // return; // } // IS_FWD_MODULE_LOADED = true; // } // size_t params_size = sizeof(params); // void* config[] = { // HIP_LAUNCH_PARAM_BUFFER_POINTER, // ¶ms, // HIP_LAUNCH_PARAM_BUFFER_SIZE, // ¶ms_size, // HIP_LAUNCH_PARAM_END // }; // dim3 grid(num_m_block, params.h, params.num_sm_parts); // std::string kernel_name = params.is_causal ? // "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params": // "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params"; // hipFunction_t flash_mla_func; // status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str()); // status = hipModuleLaunchKernel( // flash_mla_func, // grid.x, grid.y, grid.z, // Kernel_traits::kNThreads, 1, 1, // smem_size, // shared memory // stream, // stream // NULL, // (void**)&config // ); // if (status not_eq hipSuccess) { // printf("[flashmla] EXIT: failed to launch kernel!\n"); // return; // } // } CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_splitkv_fwd_mla_fp8_tp4(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = 
cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8_tp4; constexpr size_t smem_size = 65536; // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_splitkv_fwd_mla_fp8(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); const static bool enable_asm = get_env_("FLASH_MLA_ENABLE_ASM"); if (Kernel_traits::IS_WITH_CAT || !enable_asm) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8; constexpr size_t smem_size = 32768; // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); kernel<<>>(params); }); } else { static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR"); constexpr size_t smem_size = 32768; std::string co_file = std::string(FLASH_MLA_ASM_DIR) + "flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co"; hipError_t status = hipSuccess; static hipModule_t fwd_module_sample; static bool IS_FWD_MODULE_LOADED = false; if (IS_FWD_MODULE_LOADED == false) { status = hipModuleLoad(&fwd_module_sample, co_file.c_str()); if (status not_eq hipSuccess) { printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str()); // return 0; } IS_FWD_MODULE_LOADED = true; } 
size_t params_size = sizeof(params); void* config[] = { HIP_LAUNCH_PARAM_BUFFER_POINTER, ¶ms, HIP_LAUNCH_PARAM_BUFFER_SIZE, ¶ms_size, HIP_LAUNCH_PARAM_END }; dim3 grid(num_m_block, params.h, params.num_sm_parts); std::string kernel_name = params.is_causal ? "_ZN5flash32flash_fwd_splitkv_mla_kernel_fp8I34Flash_fwd_kernel_traits_mla_qkvfp8ILi576ELi16ELi64ELi4EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_22SharedStorageMLAFloat8IS5_EEEEv20Flash_fwd_mla_params": "_ZN5flash32flash_fwd_splitkv_mla_kernel_fp8I34Flash_fwd_kernel_traits_mla_qkvfp8ILi576ELi16ELi64ELi4EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_22SharedStorageMLAFloat8IS5_EEEEv20Flash_fwd_mla_params"; hipFunction_t flash_mla_func; status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str()); status = hipModuleLaunchKernel( flash_mla_func, grid.x, grid.y, grid.z, Kernel_traits::kNThreads, 1, 1, smem_size, // shared memory stream, // stream NULL, (void**)&config ); if (status not_eq hipSuccess) { printf("[flashmla] EXIT: failed to launch kernel!\n"); // return 0; } } CHECK_CUDA_KERNEL_LAUNCH(); dim3 grid_combine(params.b * params.h * params.seqlen_q); constexpr int kNThreads = 128; MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params ¶ms, cudaStream_t stream, bool is_with_cat) { static_assert(Headdim == 576); FLASH_ASSERT(params.d_v == 512); FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV // printf(" params.ngroups = %d \n", params.ngroups); if (is_with_cat) { if constexpr (std::is_same_v) { using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, T, To, 512, true>; 
run_flash_splitkv_fwd_mla_fp8>(params, stream); } else { // q为bf16 using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, cutlass::float_e4m3_t, To, 512, true, T>; run_flash_splitkv_fwd_mla_fp8>(params, stream); } return; } if constexpr (std::is_same_v) { if (params.ngroups >= 64) { using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP1<576, 64, 64, 8, T, To, 512>; run_flash_splitkv_fwd_mla_fp8_tp1>(params, stream); } else if (params.ngroups > 16) { using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP4<576, 32, 64, 4, T, To, 512>; run_flash_splitkv_fwd_mla_fp8_tp4>(params, stream); } else { using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, T, To, 512, false>; run_flash_splitkv_fwd_mla_fp8>(params, stream); } } }