// NOTE(review): this copy of the file appears to have been damaged by markup
// stripping. Every `<...>` template-argument list is missing (e.g. bare
// `template`, `reinterpret_cast(...)`, `Shape, Int>{}`, `mla_kernel<<>>`), the
// target of the first #include below is gone, and `&params` has been decoded
// to `¶ms`. None of that is repaired here -- restore those tokens from the
// upstream source before compiling.
#include
#include "utils.h"
#include "params.h"
#include "config.h"
#include "traits.h"
#include "softmax.h"

using namespace cute;

namespace gfx93 {

// compute_attn_1rowblock_splitkv_mla_kvfp8
//
// Attention for one query row-block (m_block, kBlockM rows) of request `bidb`
// and head `bidh`, over the KV page blocks [n_block_min, n_block_max), visited
// from n_block_max - 1 down to n_block_min.
//
//  - K pages are located through the paged block table: the page for block
//    n_block is params.block_table[bidb * block_table_batch_stride + n_block],
//    scaled by params.k_batch_stride.
//  - V is not loaded separately: sK and sV are both views of plan.smem_v, and
//    the P*V step reads its B operand (tOsVt) from the K tile in shared memory.
//  - The K tile, the P staging tile and the row-max scratch are double-buffered
//    by the parity of n_block. Block n_block - 1 is prefetched into the other
//    parity while the current block's P*V runs.
//  - Epilogue, NoSplit == true: acc_o is normalized by normalize_softmax_lse
//    and passed through flash::convert_type (presumably to Element). It is
//    written to params.o_ptr, and the LSE to params.softmax_lse_ptr.
//  - Epilogue, NoSplit == false: the float accumulator is written to
//    params.oaccum_ptr and the LSE to params.softmax_lseaccum_ptr, at split
//    index params.num_splits_ptr[bidb] + n_split_idx.
//
// NOTE(review): the whole body is wrapped in `#if 0 ... #endif`, so as written
// this function compiles to an empty body and the kernel produces no output.
// Confirm this is intentional.
// NOTE(review): the template parameter list was stripped; the body uses
// `T::...`, so it was presumably `template <typename T>`.
template __device__ void compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit) {
#if 0
    constexpr static bool Is_causal = T::Is_causal;
    constexpr int kBlockM = T::kBlockM;
    constexpr int kBlockN = T::kBlockN;
    constexpr int kHeadDim = T::kHeadDim;
    constexpr int kHeadDimV = T::kHeadDimV;
    const int tidx = threadIdx.x;
    const int lane_idx = tidx % 64;  // not referenced below
    extern __shared__ char shared_memory[];
    using SharedMemoryPlan = typename T::SharedMemoryPlan;
    SharedMemoryPlan &plan = *reinterpret_cast(shared_memory);
    using index_t = int64_t;
    using Element = typename T::Element;

    // Head offset into the K cache only. The per-page offset from the block
    // table is added to gK's base pointer later (gK_data + offset_k).
    const index_t row_offset_k = (bidh) * params.k_head_stride;
    const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));
    // gV is built on the same base (params.k_ptr) as gK; it is not referenced below.
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{}));
    Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{});
    // sV and sK alias the same shared buffer (plan.smem_v).
    Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});
    Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{});
    Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{});
    Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{});  // not referenced below
    Tensor sVtNoSwizzle_fp8 = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle_fp8{});
    // Shared scratch handed to the softmax helpers (row max during rescale,
    // row sum during final normalization).
    Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{});
    Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{});
    typename T::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    typename T::TiledMma_16_16_32 tiled_mma_16_16_32;
    auto thr_mma_16_16_32 = tiled_mma_16_16_32.get_thread_slice(tidx);  // not referenced below
    typename T::TiledMma_O_16_32_16 tiled_mma_o_16_32_16;
    auto thr_mma_o_16_32_16 = tiled_mma_o_16_32_16.get_thread_slice(tidx);  // not referenced below
    typename T::TiledMma_int8 tiled_mma_int8;
    auto thr_mma_int8 = tiled_mma_int8.get_thread_slice(tidx);
    typename T::TiledMma_O tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);

    // Load Q through LDS (shared memory), since Q is shared by all 4 warps.
    typename T::GmemTiledCopyQ gmem_tiled_copy_Q;
    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
    Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
    Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ)));
    // Only threads 0..127 issue the global->shared Q copy. The last argument is
    // the number of valid query rows remaining in this row-block. The barrier
    // below is reached by all threads.
    if (threadIdx.x < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.q_seq_per_hk - m_block * kBlockM);
    __syncthreads();
    // Q is then copied shared->register once (tSrQ) and reused for every KV block below.
    auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    __syncthreads();

    auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma_16_16_32);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
    Tensor tSrK = thr_mma.partition_fragment_B(gK);
    Tensor tSrK_int8 = thr_mma_int8.partition_fragment_B(gK);
    auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o_16_32_16);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
    Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle_fp8);

    // Number of leading iterations (starting at the highest block) that apply
    // column masking: 1 when non-causal, ceil(kBlockM / kBlockN) + 1 when causal.
    constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
    const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
    int n_block = n_block_max - 1;
    // Even-parity base pointers. Odd n_block selects base + sk_size (K / tSsK / tOsVt),
    // base - sk_size (sP) or base - 8192 (row-max scratch); see the loops below.
    const auto sk_data = sK.data();
    const auto sRow_max_reduce_buffer_data = sRow_max_reduce_buffer.data();
    constexpr auto sk_size = size(sK);
    const auto sP_data = sP.data();
    const auto tSsK_data = tSsK.data();
    const auto tOsVt_data = tOsVt.data();
    const auto gK_data = gK.data();
    constexpr static int BUFFER_SIZE = 1;
    constexpr short int wait_cnt = 8;  // not referenced below

    // Prologue: resolve the physical page of block n_block and start copying its
    // K tile into the parity-selected shared buffer.
    {
        int cur_block_table;
        const int *cur_block_table_ptr = block_table + n_block;
        // cur_block_table = block_table[n_block];  (scalar s_load_dword, then wait on lgkmcnt;
        // an earlier comment here said [n_block - 1], which does not match the pointer above)
        asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table));
        index_t offset_k = cur_block_table * params.k_batch_stride;
        gK.data() = gK_data + (offset_k);
        sK.data() = n_block % 2 == 1 ? sk_data + sk_size : sk_data;
        // Chunks 0..7 of the tile are copied global->shared here. Chunk index 8 is
        // loaded by buffer_load_copy_fp8 into a register buffer inside the loops below.
        // The last argument is the number of valid key rows left in this block.
#pragma unroll
        for (int i = 0; i < 8; i++) {
            flash::lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN);
        }
    }
    // KV_DTYPE and is_scale_equal_one are only referenced by commented-out code.
    // The dequantization scale passed to the GEMM helpers is fixed at 1.0.
    constexpr static Fp8KVCacheDataType KV_DTYPE = Fp8KVCacheDataType::kFp8E5M2;
    constexpr static bool is_scale_equal_one = true;
    const float k_scale = 1.0;
    Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{});
    clear(acc_o);
    flash::Softmax(acc_o)> softmax;

    // Masking iterations: the first n_masking_steps blocks (highest indices first).
    for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});
        clear(acc_s);
        Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8);  // not referenced below
        {
            // acc_s accumulates Q*K for this block:
            //  - gemm_rs_fp8 reads K from the shared buffer (tSsK);
            //  - gemm_k_rs_fp8 reads the chunk that buffer_load_copy_fp8 placed in `buffer`.
            tSsK.data() = n_block % 2 == 1 ? tSsK_data + sk_size : tSsK_data;
            uint32x4_t buffer[BUFFER_SIZE];
            flash::buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
#if 0
#else
            flash::gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
            // asm volatile("s_barrier\n\t");
            // Wait for all outstanding vector-memory loads (including the load into
            // `buffer` above) before gemm_k_rs_fp8 consumes it.
            asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
            flash::gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
#endif
            // asm volatile("s_barrier\n\t");
        }
        // if (block0()) {
        //     printf(" tid = %d %.2f %.2f %.2f %.2f \n",tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3));
        // }
        // Column masking, using each accumulator element's (row, col) coordinate in the tile.
        Tensor cS = make_identity_tensor(Shape, Int>{});
        Tensor tScS = thr_mma.partition_C(cS);
        for (int i = 0; i < size(acc_s); ++i) {
            // if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v) {
            //     acc_s(i) *= k_scale;
            // }
            if constexpr (!Is_causal) {
                // Mask columns at or past the end of the key sequence.
                if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
            } else {
                // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                // In the code, seqlen_q is params.q_seq_per_hk and ngroups is params.q_head_per_hk.
                int row = int(get<0>(tScS(i)));
                int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
                if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
            }
        }
        // NOTE(review): odd blocks use a row-max scratch 8192 elements *before*
        // plan.smem_row_max (magic constant). Confirm against the SharedMemoryPlan
        // layout that this region is reserved for it.
        sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
        if constexpr (n_masking_steps == 1) {
            softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
        } else {
            // The two calls below presumably differ only in their (stripped) template
            // arguments, e.g. an Is_first flag set on the first masking step.
            const bool is_first_masking_step = masking_step == 0;
            is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
        }
        Tensor rP = flash::convert_type(acc_s);
        // P is staged through a parity-selected shared tile (odd blocks: sk_size
        // elements before plan.smem_p) and re-partitioned for the output MMA.
        sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data;
        Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
        // Block-wide barrier before the prefetch below writes into the other-parity
        // shared buffers. The condition below is uniform across the block.
        __syncthreads();
        // Prefetch: resolve the page of block n_block - 1 and start copying it into
        // the other-parity buffers before this block's P*V step.
        if (n_block > n_block_min) {
            int cur_block_table;
            const int *cur_block_table_ptr = block_table + n_block - 1;
            // cur_block_table = block_table[n_block - 1];
            asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table));
            index_t offset_k = cur_block_table * params.k_batch_stride;
            gK.data() = gK_data + (offset_k);
            sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data;
            sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
            sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data;
            tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data;
#pragma unroll
            for (int i = 0; i < 8; i++) {
                flash::lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN);
            }
            // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN);
            // gK.data() = gK.data() + (-offset_k);
        }
        {
            // Output step for the current block. V (tOsVt) is read from the current
            // block's parity of the shared K/V tile; acc_o is the accumulator argument.
            tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data;
#if 0
#else
            flash::gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale);
#endif
        }
    }

    // Remaining blocks: same pipeline as above, without column masking.
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});
        clear(acc_s);
        Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8);  // not referenced below
        {
            tSsK.data() = n_block % 2 == 1 ? tSsK_data + sk_size : tSsK_data;
            uint32x4_t buffer[BUFFER_SIZE];
            // NOTE(review): gK points at block n_block here, yet the row bound passed is
            // seqlen_k - (n_block - 1) * kBlockN. The masking loop uses
            // seqlen_k - n_block * kBlockN for the same call. Likely harmless if every
            // block reaching this loop is full, but the two loops should agree.
            flash::buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - (n_block - 1) * kBlockN);
#if 0
#else
            flash::gemm_rs_fp8(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
#endif
            // asm volatile("s_barrier\n\t");
            // Wait for the register-chunk load above before consuming `buffer`.
            asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
            flash::gemm_k_rs_fp8(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
            // asm volatile("s_barrier\n\t");
        }
        sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
        softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
        Tensor rP = flash::convert_type(acc_s);
        sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data;
        Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
        // Barrier before the prefetch overwrites the other-parity shared buffers.
        __syncthreads();
        if (n_block > n_block_min) {
            int cur_block_table;
            const int *cur_block_table_ptr = block_table + n_block - 1;
            // cur_block_table = block_table[n_block - 1];
            asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table));
            index_t offset_k = cur_block_table * params.k_batch_stride;
            gK.data() = gK_data + (offset_k);
            sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data;
            // The row-max / sP / tSsK pointers are re-selected at the top of the next
            // iteration, so they are not updated here.
            // sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
            // sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data;
            // tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data;
#pragma unroll
            for (int i = 0; i < 8; i++) {
                flash::lds_direct_copy_fp8(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN);
            }
            // buffer_load_copy_fp8(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN);
            // gK.data() = gK.data() + (-offset_k);
        }
        {
            tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data;
#if 0
#else
            flash::gemm1_rs_fp8(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale);
#endif
            // tOsVt.data() = (n_block - 1) % 2 ? tOsVt_data + sk_size : tOsVt_data;
        }
    }

    // Epilogue: write LSE and output. `Split` selects the destination pointers and
    // the output row stride (kHeadDimV for oaccum, params.o_row_stride for o).
    using ElementAccum = float;
    if (NoSplit) {
        using ElementO = Element;
        const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
        constexpr bool Split = false;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
        Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
        // if (block0() && tidx < 64)
        // {
        //     printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
        // }
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), Shape>{}, Stride<_1>{});
        Tensor caccO = make_identity_tensor(Shape, Int>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor taccOcO = thr_mma_o.partition_C(caccO);
        Tensor rO = flash::convert_type(acc_o);
        // Only threads whose first output element is in column 0 write LSE, one value
        // per owned row, skipping rows past the valid query count.
        if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO(0, mi, 0));
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    gLSEaccum(row) = lse(mi);
                }
            }
        }
        {
            // Hand-written register->global store, with an explicit (row, col) mapping.
            // NOTE(review): row is tidx % 16 and the m loop runs once, so only rows 0..15
            // are written. This presumably assumes kBlockM == 16 and a specific
            // TiledMma_O register layout -- confirm against Traits.
            // using result_type = cutlass::Array;
            // int tidx = threadIdx.x;
            int col = 0;
            int warpid = tidx / 64;
            for (int m = 0; m < 1; m++) {
                const int row = tidx % 16;
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    for (int n = 0; n < size<2>(acc_o); n++) {
                        col = (tidx % 64 / 16) * 2 + warpid * 64 + n * 256;
                        for (int ei = 0; ei < 16; ei += 2) {
                            gOaccum(row, col) = rO(ei, m, n);
                            gOaccum(row, col + 1) = rO(ei + 1, m, n);
                            col += 8;
                        }
                    }
                }
            }
        }
    } else {
        // Split path: store the float accumulator and LSE into the per-split scratch
        // buffers at split index params.num_splits_ptr[bidb] + n_split_idx.
        using ElementO = float;
        int split_idx = params.num_splits_ptr[bidb] + n_split_idx;
        constexpr bool Split = true;
        const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V;  // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
        const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)), Shape>{}, Stride<_1>{});
        Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
        Tensor caccO = make_identity_tensor(Shape, Int>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor taccOcO = thr_mma_o.partition_C(caccO);
        if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO(0, mi, 0));
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    gLSEaccum(row) = lse(mi);
                }
            }
        }
        {
            // Same (row, col) store mapping as the NoSplit path, but writes acc_o
            // directly (no convert_type).
            // using result_type = cutlass::Array;
            // int tidx = threadIdx.x;
            int col = 0;
            int warpid = tidx / 64;
            for (int m = 0; m < 1; m++) {
                const int row = tidx % 16;
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    for (int n = 0; n < size<2>(acc_o); n++) {
                        col = (tidx % 64 / 16) * 2 + warpid * 64 + n * 256;
                        for (int ei = 0; ei < 16; ei += 2) {
                            gOaccum(row, col) = acc_o(ei, m, n);
                            gOaccum(row, col + 1) = acc_o(ei + 1, m, n);
                            col += 8;
                        }
                    }
                }
            }
        }
    }
#endif
}

// flash_fwd_splitkv_mla_kvfp8_kernel
//
// Grid layout:
//  - blockIdx.x = query row-block (m_block);
//  - blockIdx.y = head (bidh);
//  - blockIdx.z = partition index into params.tile_scheduler_metadata_ptr.
//
// Each block walks requests [begin_req_idx, end_req_idx] of its partition:
//  - the first request starts at begin_block_idx / begin_split_idx;
//  - the last request stops at end_block_idx;
//  - any other request covers all ceil(seqlen_k / PAGE_BLOCK_SIZE) blocks.
//
// The compute call is compiled only for __gfx936__ / __gfx938__. On other
// targets the loop body does no attention work.
// NOTE(review): the template parameter list was stripped (presumably `template <typename T>`).
template __global__ void __launch_bounds__(T::NUM_THREADS, 1) flash_fwd_splitkv_mla_kvfp8_kernel(const DenseAttnDecodeParams_fp8 params) {
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;
    DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
    // if (thread0())
    // {
    //     printf("m_block = %d sched_meta.begin_req_idx = %d \n ", m_block, sched_meta.begin_req_idx);
    // }
    // Uniform early exit: sched_meta is the same for every thread in the block.
    if (sched_meta.begin_req_idx >= params.b) return;
    for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
        constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
        const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
        int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
        const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
        int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
        // Only the first and last requests of a partition can be split. Any request
        // strictly in between is processed whole.
        const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
        // Block-wide barrier between consecutive requests, presumably because the
        // previous request's shared-memory buffers are reused. The branch is uniform.
        if (batch_idx > sched_meta.begin_req_idx) {
            __syncthreads();
        }
#if defined(__gfx936__) || defined(__gfx938__)
        compute_attn_1rowblock_splitkv_mla_kvfp8(params, batch_idx, bidh, m_block, n_split_idx, seqlen_k, start_block_idx, end_block_idx, is_no_split );
#endif
    }
}

// run_flash_splitkv_mla_kvfp8_kernel
//
// Host launcher. It:
//  - checks params.d / params.d_v against Config::HEAD_DIM_K / HEAD_DIM_V;
//  - selects the causal or non-causal instantiation via BOOL_SWITCH;
//  - launches flash_fwd_splitkv_mla_kvfp8_kernel;
//  - checks the launch with CHECK_CUDA_KERNEL_LAUNCH.
//
// NOTE(review): the template arguments and the `<<<...>>>` launch configuration
// were stripped from this copy. Given the kernel's use of blockIdx.x/y/z,
// presumably grid = (num_m_block, heads, partitions) with smem_size dynamic
// shared memory -- restore from upstream.
// NOTE(review): smem_size is a fixed 65536 bytes. Consider asserting that it
// covers sizeof(typename T::SharedMemoryPlan).
template void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 ¶ms) {
    FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
    FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
    constexpr size_t smem_size = 65536;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        using T = Traits;
        const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
        auto mla_kernel = &flash_fwd_splitkv_mla_kvfp8_kernel;
        mla_kernel<<>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}
}  // namespace gfx93