#include #include "utils.h" #include "params.h" #include "config.h" #include "traits.h" #include "softmax.h" using namespace cute; namespace sm90 { template __device__ void compute_attn_1rowblock_splitkv_mla_qkvfp8_gfx938(const DenseAttnDecodeParams_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit) { constexpr int kBlockM = T::kBlockM; constexpr int kBlockN = T::kBlockN; constexpr int kHeadDim = T::kHeadDim; constexpr int kHeadDimV = T::kHeadDimV; const int tidx = threadIdx.x; extern __shared__ char shared_memory[]; using SharedMemoryPlan = typename T::SharedMemoryPlan; SharedMemoryPlan &plan = *reinterpret_cast(shared_memory); using index_t = int64_t; using Element = typename T::Element; const index_t row_offset_k = (bidh) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*576 Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));//64*512 Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{}); //64,512 Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{}); //16,576 Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});//64,512 Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{}); //16*64 Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});//512,64 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{}); //64 Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{}); typename T::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64 typename T::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32 const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{}));//16*576 flash::lds_direct_copy_qkvfp8(gQ, sQ, 0, params.q_row_stride, params.q_seq_per_hk - m_block * kBlockM); flash::lds_direct_copy_qkvfp8(gQ, sQ, 1, params.q_row_stride, params.q_seq_per_hk - m_block * kBlockM); flash::lds_direct_copy_qkvfp8(gQ, sQ, 2, params.q_row_stride, params.q_seq_per_hk - m_block * kBlockM); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); __syncthreads(); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(gK); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; constexpr static int k1_loops = size<2>(tOrVt); flash::Softmax<1> softmax; union Fp8_storage { intx4_t data; intx2_t p[2]; }; v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1; c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f; c0_1.x = 0.0f; c0_1.y = 0.0f; c0_1.z = 0.0f; c0_1.w = 0.0f; c1_0.x = 0.0f; c1_0.y = 0.0f; c1_0.z = 0.0f; c1_0.w = 0.0f; c1_1.x = 0.0f; c1_1.y = 0.0f; c1_1.z = 0.0f; c1_1.w = 0.0f; c2_0.x = 0.0f; c2_0.y = 0.0f; c2_0.z = 0.0f; c2_0.w = 0.0f; c2_1.x = 0.0f; c2_1.y = 0.0f; c2_1.z = 0.0f; c2_1.w = 0.0f; c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f; c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f; for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); clear(acc_s); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #if 1 flash::lds_direct_copy_qkvfp8(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy_qkvfp8(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN); constexpr static int BUFFER_SIZE = 1; uint128_t buffer[BUFFER_SIZE]; flash::buffer_load_copy_qkvfp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s); asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s); asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2)); cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s); asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3)); cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4)); cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5)); cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6)); cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7)); Fp8_storage v3_0, v3_1; flash::__ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data); flash::__ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data); cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s); asm volatile("s_waitcnt vmcnt(0) \n\t"); flash::buffer_to_tensor(buffer[0], tSrK, 8); cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); #else #endif gK.data() = gK.data() + (-offset_k); // if (thread0()) // { // printf(" %.2f %.2f \n", acc_s(0), acc_s(1)); // } Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!T::Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (q_seq_per_hk - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (q_seq_per_hk - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? softmax.template softmax_rescale_o_fp8(acc_s, sRow_max_reduce_buffer, params.scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1) : softmax.template softmax_rescale_o_fp8(acc_s, sRow_max_reduce_buffer, params.scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1); } Fp8_storage data_fp8; { int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; int32_t result; result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0), acc_s(1), result, false); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2), acc_s(3), result, true); int32_t* lds_ptr = reinterpret_cast(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0])); *lds_ptr = result; __syncthreads(); data_fp8.data = *reinterpret_cast(&(sP[tid * 16])); } { Fp8_storage v0_0, v0_1; Fp8_storage v1_0, v1_1; Fp8_storage v2_0, v2_1; __builtin_amdgcn_sched_barrier(0); flash::__ds_read_m32x32_row_col_rrow<0, 0, 0>(tOsVt, v0_0.data); flash::__ds_read_m32x32_row_col_rrow<1, 0, 1>(tOsVt, v1_0.data); flash::__ds_read_m32x32_row_col_rrow<2, 0, 2>(tOsVt, v2_0.data); c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[0], c3_0, true, false); c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[1], c3_1, true, false); c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[0], c3_0, true, false); c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[1], c3_1, true, false); c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[0], c0_0, true, false); c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[1], c0_1, true, false); c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[0], c1_0, true, false); c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[1], c1_1, true, false); c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[0], c2_0, true, false); c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[1], c2_1, true, false); flash::__ds_read_m32x32_row_col_rrow<0, 1, 0>(tOsVt, v0_1.data); flash::__ds_read_m32x32_row_col_rrow<1, 1, 1>(tOsVt, v1_1.data); flash::__ds_read_m32x32_row_col_rrow<2, 1, 2>(tOsVt, v2_1.data); c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[0], c0_0, true, false); c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[1], c0_1, true, false); c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[0], c1_0, true, false); c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[1], c1_1, true, false); c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[0], c2_0, true, false); c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[1], c2_1, true, false); __builtin_amdgcn_sched_barrier(0); } } Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w; acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w; acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w; acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w; acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w; acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w; acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w; acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w; if (NoSplit) { using ElementO = cutlass::bfloat16_t; const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; constexpr bool Split = false; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_fp8(acc_o, sRow_sum_reduce_buffer, params.scale_softmax, 1.0f); // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1))); // } Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); Tensor rO = flash::convert_type(acc_o); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int n = 0; n < size<2>(acc_o); n++) { col = (tidx % 64 / 16) * 4 + warpid * 32 + n * 128; for (int ei = 0; ei < 8; ei +=4) { gOaccum(row, col) = rO(ei, m, n); gOaccum(row, col + 1) = rO(ei + 1, m, n); gOaccum(row, col + 2) = rO(ei + 2, m, n); gOaccum(row, col + 3) = rO(ei + 3, m, n); col += 16; } } } } } } else { using ElementO = float; int split_idx = params.num_splits_ptr[bidb] + n_split_idx; constexpr bool Split = true; const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)), Shape>{}, Stride<_1>{}); Tensor lse = softmax.template normalize_softmax_lse_fp8(acc_o, sRow_sum_reduce_buffer, params.scale_softmax, 1.0f); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int n = 0; n < size<2>(acc_o); n++) { col = (tidx % 64 / 16) * 4 + warpid * 32 + n * 128; for (int ei = 0; ei < 8; ei +=4) { gOaccum(row, col) = acc_o(ei, m, n); gOaccum(row, col + 1) = acc_o(ei + 1, m, n); gOaccum(row, col + 2) = acc_o(ei + 2, m, n); gOaccum(row, col + 3) = acc_o(ei + 3, m, n); col += 16; } } } } } } } template __global__ void __launch_bounds__(T::NUM_THREADS, 1) flash_fwd_splitkv_mla_qkvfp8_kernel(const DenseAttnDecodeParams_fp8 params) { const int m_block = blockIdx.x; const int bidh = blockIdx.y; const int partition_idx = blockIdx.z; DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; // if (thread0()) // { // printf("m_block = %d sched_meta.begin_req_idx = %d \n ", m_block, sched_meta.begin_req_idx); // } if (sched_meta.begin_req_idx >= params.b) return; for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { constexpr int kBlockN = T::PAGE_BLOCK_SIZE; const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0; int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN); const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true); if (batch_idx > sched_meta.begin_req_idx) { __syncthreads(); } #if defined(__gfx938__) compute_attn_1rowblock_splitkv_mla_qkvfp8_gfx938(params, batch_idx, bidh, m_block, n_split_idx, seqlen_k, start_block_idx, end_block_idx, is_no_split ); #endif } } template void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams_fp8 ¶ms) { FLASH_ASSERT(params.d == Config::HEAD_DIM_K); FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V); constexpr size_t smem_size = 32768; // printf("batch_idx = %d\n ", smem_size); // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // #if defined(__gfx938__) BOOL_SWITCH(params.is_causal, Is_causal, [&] { using T = Traits; const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); auto mla_kernel = &flash_fwd_splitkv_mla_qkvfp8_kernel; mla_kernel<<>>(params); }); // #endif // cudaLaunchConfig_t mla_kernel_config = { // dim3(num_m_block, params.h_k, params.num_sm_parts), // dim3(T::NUM_THREADS, 1, 1), // smem_size, // params.stream, // mla_kernel_attributes, // 1 // }; // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); CHECK_CUDA_KERNEL_LAUNCH(); } }