Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
......@@ -284,4 +284,560 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_int8_prefix_prefill_kernel(c
compute_attn_int8_prefix_prefill_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// GFX938 kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
// Debug helper: when FP8_DEBUG is true, dumps this lane's fp8 P fragments
// (held in p_reg) into a global debug buffer; when false it compiles to nothing.
//
//   p_reg_ptr        base of the debug buffer; the slice for (bidb, bidh) starts at
//                    (bidb * h + bidh) * actual_seqlen_q * actual_seqlen_k elements
//   p_reg            per-lane packed fragments, read as 32-bit words via .i32[]
//   bidb, bidh, h    batch index, head index, number of heads
//   actual_seqlen_q/k  used as the row count / row pitch of the debug slice
//   m_block, n_block_loop, warp_id, lane_id  select the tile and the lane's position in it
//
// Is_even_MN, max_seq_q_offset and max_seq_kv_offset are accepted for signature
// compatibility but are not read here, so no bounds check is applied to the stores.
// NOTE(review): the fragment index `k_loop * 2 + n_idx` looks like it assumes
// WARP_N / 16 == 2 — confirm against the kernel traits in use.
template<bool FP8_DEBUG, bool Is_even_MN, int kBlockM, int kBlockN, int WARP_M, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_debug_p_reg(
    Element* p_reg_ptr,
    union_vec32_fp8 p_reg[WARP_M / 16],
    int bidb,
    int bidh,
    int h,
    int actual_seqlen_q,
    int actual_seqlen_k,
    int max_seq_q_offset,
    int max_seq_kv_offset,
    int m_block,
    int n_block_loop,
    int warp_id,
    int lane_id
) {
    if constexpr (FP8_DEBUG) {
        __builtin_amdgcn_sched_barrier(0);
        // 64-bit arithmetic: batch * heads * seqlen_q * seqlen_k easily exceeds 2^31.
        const int64_t slice_offset =
            (int64_t(bidb) * h + bidh) * int64_t(actual_seqlen_q) * int64_t(actual_seqlen_k);
        Element* p_reg_buffer = p_reg_ptr + slice_offset;
        #pragma unroll
        for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
            #pragma unroll
            for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
                #pragma unroll
                for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
                    const int row_pos = m_block * kBlockM + warp_id * WARP_M + ((lane_id & 15) >> 2) * 8 + m_idx * 4 + (lane_id & 3);
                    const int col_pos = (lane_id >> 4) * 8 + n_idx * 4 + k_loop * WARP_N + n_block_loop * kBlockN;
                    // One 32-bit store per fragment word; the row offset is widened as well.
                    *(int32_t*)(p_reg_buffer + int64_t(row_pos) * actual_seqlen_k + col_pos) = p_reg[m_idx].i32[k_loop * 2 + n_idx];
                }
            }
        }
        // Keep the debug stores from being scheduled into the surrounding pipeline
        // and wait for outstanding memory operations before continuing.
        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_waitcnt(0);
        __builtin_amdgcn_sched_barrier(0);
    }
}
#include "fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h"
#include "fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h"
#include "fwd/gfx938/fp8_softmax_gfx938.h"
#include "fwd/gfx938/fp8_epilogue.h"
// Computes FP8 attention for one query row block (m_block) of one (batch, head) pair.
//
// Visible flow:
//   1. Derive Q/K/V/O element offsets (varlen bshd / varlen bhsd / fixed-length) and the
//      per-head descale factors; fold q_descale * k_descale into the softmax scales.
//   2. Stage the Q tile through LDS into registers, then prefetch the first K tile.
//   3. Main loop over KV tiles that need no masking: QK gemm -> V prefetch -> softmax ->
//      fp32-to-fp8 conversion of P -> PV gemm combined with the next K prefetch.
//   4. Tail loop of n_masking_steps tiles: same pipeline plus the causal mask or, for the
//      non-causal uneven case, masking of the out-of-range columns of the last tile.
//   5. Epilogue: rescale acc_o, optionally store LSE, store the output tile.
//
// params  : kernel parameters (pointers, strides, scales); bidb/bidh: batch and head index;
// m_block : query row-block index; warp_id: wave index within the thread block.
//
// Note: q_lds, k_lds and v_lds all alias the start of the same dynamic shared buffer, so
// the order of the helper calls below matters.
// Note: kBlockK, kHeadDimV, kNWarps, STAGES, WARP_NUM, index_t, tx and p_reg_ptr are not
// referenced by live code in this function (p_reg_ptr only by commented-out debug calls).
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_fp8_attn_mha_1rowblock_gfx938(const Params &params, const int bidb, const int bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using Element_k = typename Kernel_traits::Element_k;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int WARP_K = 32;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
// Sequence-length information for the batch entry handled by this thread group
const flash::BlockInfo</*Varlen=*/Is_Varlen> binfo(params, bidb);
// Bounds check: return if this row block starts at or past the end of Q, or there are no keys
int max_seq_q_offset = binfo.actual_seqlen_q - m_block * kBlockM;
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k <= 0/* || bidh >= h*/) return;
// Wave id: now supplied by the caller as warp_id; the previous in-function computation is kept for reference
// int __warp_id = threadIdx.x >> 6;
// int warp_id = __builtin_amdgcn_readfirstlane(__warp_id);
// LDS definition (original note: 128x128 fp8 values, 16384 bytes)
// __shared__ int8_t lds[16384 + 4096 + 16384 + 4096];
extern __shared__ int8_t lds[];
// Q, K and V staging areas all start at offset 0 of the same buffer
int8_t* q_lds = lds + 0;
int8_t* k_lds = lds + 0;
int8_t* v_lds = lds + 0;
// ========================================== Compute offsets ===========================================
int64_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int64_t row_offset_lse_base;
if constexpr (Is_Varlen) {
if constexpr (Layout == 1) { /* bshd: q/o are [total_q, h, d] */
row_offset_q = (int64_t(binfo.sum_s_q) + m_block * kBlockM) * int64_t(params.q_row_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(binfo.sum_s_k) * int64_t(params.k_row_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(binfo.sum_s_k) * int64_t(params.v_row_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = int64_t(binfo.sum_s_q) * int64_t(params.o_head_stride) * params.h + params.o_head_stride * bidh + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = bidh * int64_t(params.total_q) + binfo.sum_s_q;
} else { /* bhsd */
row_offset_q = int64_t(binfo.sum_s_q) * int64_t(params.q_row_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(params.q_row_stride);
row_offset_k = int64_t(binfo.sum_s_k) * int64_t(params.k_row_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(binfo.sum_s_k) * int64_t(params.v_row_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = int64_t(binfo.sum_s_q) * int64_t(params.o_row_stride) + bidh * params.o_head_stride + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = bidh * int64_t(params.total_q) + binfo.sum_s_q;
}
} else {
row_offset_q = bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(params.q_row_stride);
row_offset_k = bidb * int64_t(params.k_batch_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = bidb * int64_t(params.v_batch_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = (bidb * params.h + bidh) * int64_t(binfo.actual_seqlen_q);
}
// Descale factors are indexed per (batch, head); K/V use the KV head index bidh / h_h_k_ratio
int row_offset_q_descale = bidb * params.q_descale_batch_stride + bidh * params.q_descale_head_stride;
int row_offset_k_descale = bidb * params.k_descale_batch_stride + int(bidh / params.h_h_k_ratio) * params.k_descale_head_stride;
int row_offset_v_descale = bidb * params.v_descale_batch_stride + int(bidh / params.h_h_k_ratio) * params.v_descale_head_stride;
Element_k* q_ptr = reinterpret_cast<Element_k*>(params.q_ptr) + row_offset_q;
Element_k* k_ptr = reinterpret_cast<Element_k*>(params.k_ptr) + row_offset_k;
Element_k* v_ptr = reinterpret_cast<Element_k*>(params.v_ptr) + row_offset_v;
ElementAccum* q_descale_ptr = reinterpret_cast<ElementAccum*>(params.q_descale_ptr);
ElementAccum* k_descale_ptr = reinterpret_cast<ElementAccum*>(params.k_descale_ptr);
ElementAccum* v_descale_ptr = reinterpret_cast<ElementAccum*>(params.v_descale_ptr);
ElementAccum q_descale = q_descale_ptr[row_offset_q_descale];
ElementAccum k_descale = k_descale_ptr[row_offset_k_descale];
// The Q and K descales are folded into the softmax scale factors
ElementAccum qk_descale = q_descale * k_descale;
ElementAccum softmax_scale = params.scale_softmax * qk_descale;
ElementAccum softmax_scale_log2 = params.scale_softmax_log2 * qk_descale;
ElementAccum v_descale = v_descale_ptr[row_offset_v_descale];
// acc_o_ptr = reinterpret_cast<ElementAccum*>(acc_o_ptr) + row_offset_o;
ElementAccum* softmax_lse_ptr = reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr);
Element_k* p_reg_ptr = reinterpret_cast<Element_k *>(params.p_ptr);
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
// ======================================================== Load Q ======================================================================
fp8_prefetch_q_to_lds<Is_even_MN, kHeadDim, WARP_M, Element_k>(q_ptr, q_lds, warp_id, params.q_row_stride, max_seq_q_offset);
// Variables needed by the bank-conflict-avoiding access pattern (original author's note)
int tx = threadIdx.x;
int lane_id = tx & 63;
// Registers for the running row max, running row sum and the output accumulator acc_o
ElementAccum scores_max[WARP_M / 16];
ElementAccum scores_sum[WARP_M / 16];
vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16];
fp8_attention_initialize<kHeadDim, WARP_M, WARP_N, ElementAccum>(scores_max, scores_sum, acc_o);
// Read Q from LDS into registers; per the original author no synchronization is needed here
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64];
load_q_from_lds_to_vgpr<kHeadDim, WARP_M, Element_k>(q_regs, q_lds, warp_id, lane_id);
// ======================================================== Prefetch K ======================================================================
fp8_prefetch_k_to_lds<Is_even_MN, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, binfo.actual_seqlen_k);
// ======================================================== Mainloop ======================================================================
// KV-tile range for this row block; with a causal mask the upper bound can be reduced
int n_block_min = 0;
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal) {
n_block_max = std::min(n_block_max, ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + 0/*params.window_size_right*/, kBlockN));
}
constexpr int n_masking_steps = (!Is_causal/* && !Is_local*/) ? 1: ceil_div(kBlockM, kBlockN); // the current use case may need kBlockM == kBlockN, mainly so that the prefetched K data stays correct
constexpr bool Assume_valid_rows = !Is_local && (!Is_causal || !Is_Varlen);
for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
// Number of KV rows remaining from the start of this tile
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// ======================================================== QK gemm ======================================================================
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds<Is_even_MN, kBlockN, kHeadDim, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== Softmax ======================================================================
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDim, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// ========================================== cvt ===============================================
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<1, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
fp8_pv_gemm_and_prefetch_k<true/*PrefetchK*/, Is_even_MN, kHeadDim, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset - kBlockN);
// Advance the V pointer to the next tile. k_ptr is not advanced here;
// NOTE(review): presumably the PrefetchK=true helper advances k_ptr itself — confirm.
v_ptr += kBlockN * params.v_row_stride;
}
// ========================================== Rest ===============================================
// The remaining tiles need masking (causal mask, or boundary mask in the non-causal case)
int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
// Number of KV rows remaining from the start of this tile
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// ======================================================== QK gemm ======================================================================
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds<Is_even_MN, kBlockN, kHeadDim, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
// ======================================================== causal mask ==================================================================
if constexpr (Is_causal) {
fp8_apply_causal_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN, lane_id);
}
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== mask ==================================================================
// Matches the fp16 forward path: the non-causal tail loop must mask the out-of-range columns of the last, partial KV tile.
if constexpr (!Is_causal && !Is_local) {
if constexpr (!Is_even_MN) {
fp8_apply_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, max_seq_kv_offset, 0, lane_id);
}
}
// ======================================================== Softmax ======================================================================
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDim, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// ========================================== cvt ===============================================
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<0, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
// K only needs to be prefetched inside the tail loop when it runs more than one step
constexpr bool PrefetchK = n_masking_steps > 1;
fp8_pv_gemm_and_prefetch_k<PrefetchK, Is_even_MN, kHeadDim, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset - kBlockN);
// Advance the K/V pointers; K is advanced here only when the helper was not asked to prefetch
if (not PrefetchK) {
k_ptr += kBlockN * params.k_row_stride;
}
v_ptr += kBlockN * params.v_row_stride;
}
// ========================================== rescale by scores_sum ==========================================
// Rescale acc_o using scores_sum (softmax_scale and v_descale are passed to the helper as well)
ElementAccum lse[WARP_M / 16];
if constexpr (Return_softmax) {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDim, WARP_M, WARP_N, true/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
} else {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDim, WARP_M, WARP_N, false/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
}
// ========================================== Store LSE ==========================================
if constexpr (Return_softmax) {
fp8_epilogue_store_lse<Is_even_MN, WARP_M, ElementAccum>(
softmax_lse_ptr, scores_max, scores_sum, lse, row_offset_lse_base, binfo.actual_seqlen_q, m_block * kBlockM + warp_id * WARP_M, lane_id);
}
// ========================================== Store output =============================================
fp8_epilogue_store_output<Is_even_MN, kBlockM, kHeadDim, WARP_M, WARP_N, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, params.o_row_stride, binfo.actual_seqlen_q);
}
// Block-level entry for the gfx938 FP8 attention path: maps the launch grid onto
// (batch, head, row block) and runs the per-row-block routine.
//
// Index mapping:
//   causal + GQA ("LPT" order): blockIdx.x -> head, blockIdx.y -> batch,
//                               blockIdx.z reversed -> row block
//   otherwise:                  blockIdx.x -> row block, blockIdx.y -> head,
//                               blockIdx.z -> batch
// For causal non-GQA (MHA) each block additionally processes the mirrored row block
// gridDim.x * 2 - 1 - m_block after a block-wide barrier.
// The body is compiled only for __gfx938__ / __gfx946__; elsewhere it is empty.
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, bool Is_GQA, int Layout, typename Params>
inline __device__ void compute_fp8_attn_gfx938(const Params &params) {
#if defined(__gfx938__) || defined(__gfx946__)
    constexpr bool Do_lpt = Is_causal && Is_GQA;
    constexpr bool Pair_row_blocks = Is_causal && !Is_GQA;  // MHA causal mask

    // Wave index within the block, made wave-uniform via readfirstlane.
    int wave_in_block = threadIdx.x / 64;
    int warp_id = __builtin_amdgcn_readfirstlane(wave_in_block);

    const int bidh = Do_lpt ? blockIdx.x : blockIdx.y;
    const int bidb = Do_lpt ? blockIdx.y : blockIdx.z;
    int m_block = Do_lpt ? gridDim.z - 1 - blockIdx.z : blockIdx.x;

    flash::compute_fp8_attn_mha_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(
        params, bidb, bidh, m_block, warp_id);

    if constexpr (Pair_row_blocks) {
        // Every wave must be done with the first row block before the second one starts.
        __builtin_amdgcn_sched_barrier(0);
        __syncthreads();
        __builtin_amdgcn_sched_barrier(0);
        flash::compute_fp8_attn_mha_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(
            params, bidb, bidh, gridDim.x * 2 - 1 - m_block, warp_id);
    }
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 Prefix Prefill (paged KV cache + varlen) for GFX938
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 prefix-prefill attention for one query row block, reading K/V from a paged KV
// cache addressed through params.block_table (variable-length batches).
//
//   params   kernel parameters: tensor pointers and strides, block_table,
//            page_block_size, descale pointers, softmax scales, window sizes
//   bidb     batch index (also selects the block_table row)
//   bidh     query head index; the KV head is bidh / params.h_h_k_ratio
//   m_block  query row-block index
//   warp_id  wave index within the thread block
//
// Flow: compute the KV-tile range [n_block_min, n_block_max) -> resolve the first page
// -> stage Q through LDS into registers -> main loop over tiles that need no
// causal/boundary masking (each iteration prefetches the next K tile) -> masking loop
// over the last n_masking_steps tiles -> rescale, optional LSE store, output store.
//
// q_lds, k_lds and v_lds alias the start of the same dynamic shared buffer, so K must be
// (re)loaded into LDS before every QK gemm.
//
// Changes relative to the previous revision:
//   * The masking loop now resolves pages exactly like the main loop (page index =
//     tile_start / page_block_size plus in-page row offset). It used to index
//     block_table by the KV-tile index and ignore the in-page offset, which is only
//     correct when page_block_size == kBlockN.
//   * When the masking loop runs more than one step, every step but the last now
//     prefetches the next K tile with the same paged helper as the main loop;
//     previously K was never reloaded for masking_step > 0.
//     NOTE(review): helper internals are not visible here — validate on hardware.
//   * LSE / Q / O offsets are computed in 64-bit, matching the non-paged MHA routine.
//   * The first K prefetch is bounded by the rows remaining from n_block_min
//     (identical to before when n_block_min == 0).
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_fp8_attn_prefix_prefill_1rowblock_gfx938(const Params &params, const int bidb, const int bidh, const int m_block, const int warp_id) {
    using Element = typename Kernel_traits::Element;
    using Element_k = typename Kernel_traits::Element_k;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int WARP_M = Kernel_traits::kWaveM;
    constexpr int WARP_N = Kernel_traits::kWaveN;
    constexpr int WARP_K = 32;

    // Varlen BlockInfo
    const flash::BlockInfo<true/*Varlen*/, false/*Is_kvcache*/> binfo(params, bidb);
    int max_seq_q_offset = binfo.actual_seqlen_q - m_block * kBlockM;
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k <= 0) return;

    // LDS: Q, K and V staging areas all start at offset 0 of the dynamic shared buffer.
    extern __shared__ int8_t lds[];
    int8_t* q_lds = lds + 0;
    int8_t* k_lds = lds + 0;
    int8_t* v_lds = lds + 0;

    // ========================================== offsets (varlen + paged) ===========================================
    const int page_block_size = params.page_block_size;
    int *block_table = params.block_table + bidb * params.block_table_batch_stride;

    int n_block_min = 0;
    if constexpr (Is_local) {
        n_block_min = max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    }
    int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
    if constexpr (Is_causal || Is_local) {
        const int window_size_right = Is_local ? params.window_size_right : 0;
        n_block_max = min(n_block_max, ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + window_size_right, kBlockN));
    }
    if (n_block_min >= n_block_max) return;

    // Page and in-page row of the first KV tile.
    const int first_block_table_idx = n_block_min * kBlockN / page_block_size;
    const int first_block_table_offset = n_block_min * kBlockN - first_block_table_idx * page_block_size;
    const int first_page = block_table[first_block_table_idx];

    const int kv_head = bidh / params.h_h_k_ratio;
    const int64_t row_offset_k = int64_t(first_page) * int64_t(params.k_batch_stride) + first_block_table_offset * int64_t(params.k_row_stride) + kv_head * params.k_head_stride;
    const int64_t row_offset_v = int64_t(first_page) * int64_t(params.v_batch_stride) + first_block_table_offset * int64_t(params.v_row_stride) + kv_head * params.v_head_stride;
    // 64-bit like the non-paged routine: bidh * total_q can exceed 2^31.
    const int64_t row_offset_lse = bidh * int64_t(params.total_q) + binfo.sum_s_q;
    int64_t row_offset_q, row_offset_o;
    if constexpr (Layout == 1) { /*bshd layout*/
        row_offset_q = (binfo.sum_s_q + m_block * kBlockM) * int64_t(params.q_row_stride) + params.q_head_stride * bidh;
        row_offset_o = binfo.sum_s_q * int64_t(params.o_head_stride) * params.h + params.o_head_stride * bidh + m_block * kBlockM * int64_t(params.o_row_stride);
    } else { /*bhsd layout*/
        row_offset_q = binfo.sum_s_q * int64_t(params.q_row_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(params.q_row_stride);
        row_offset_o = binfo.sum_s_q * int64_t(params.o_row_stride) + bidh * params.o_head_stride + m_block * kBlockM * int64_t(params.o_row_stride);
    }

    // FP8 descale: one factor per (batch, head); K/V use the KV head index.
    int row_offset_q_descale = bidb * params.q_descale_batch_stride + bidh * params.q_descale_head_stride;
    int row_offset_k_descale = bidb * params.k_descale_batch_stride + int(bidh / params.h_h_k_ratio) * params.k_descale_head_stride;
    int row_offset_v_descale = bidb * params.v_descale_batch_stride + int(bidh / params.h_h_k_ratio) * params.v_descale_head_stride;

    // Raw pointers (per the original note, the FP8 prefetch helpers call prepare_for_matrix_load internally).
    Element_k* q_ptr = reinterpret_cast<Element_k*>(params.q_ptr) + row_offset_q;
    Element_k* k_ptr = reinterpret_cast<Element_k*>(params.k_ptr) + row_offset_k;
    Element_k* v_ptr = reinterpret_cast<Element_k*>(params.v_ptr) + row_offset_v;
    ElementAccum* q_descale_ptr = reinterpret_cast<ElementAccum*>(params.q_descale_ptr);
    ElementAccum* k_descale_ptr = reinterpret_cast<ElementAccum*>(params.k_descale_ptr);
    ElementAccum* v_descale_ptr = reinterpret_cast<ElementAccum*>(params.v_descale_ptr);
    ElementAccum q_descale = q_descale_ptr[row_offset_q_descale];
    ElementAccum k_descale = k_descale_ptr[row_offset_k_descale];
    // Q and K descales are folded into the softmax scale factors.
    ElementAccum qk_descale = q_descale * k_descale;
    ElementAccum softmax_scale = params.scale_softmax * qk_descale;
    ElementAccum softmax_scale_log2 = params.scale_softmax_log2 * qk_descale;
    ElementAccum v_descale = v_descale_ptr[row_offset_v_descale];
    ElementAccum* softmax_lse_ptr = reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr);
    Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;

    // ======================================================== Load Q ======================================================================
    fp8_prefetch_q_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_M, Element_k>(q_ptr, q_lds, warp_id, params.q_row_stride, max_seq_q_offset);
    int lane_id = threadIdx.x & 63;

    // Running row max / row sum and the output accumulator.
    ElementAccum scores_max[WARP_M / 16];
    ElementAccum scores_sum[WARP_M / 16];
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16];
    fp8_attention_initialize<kHeadDim, WARP_M, WARP_N, ElementAccum>(scores_max, scores_sum, acc_o);

    // Q from LDS into registers.
    union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64];
    load_q_from_lds_to_vgpr<kHeadDim, WARP_M, Element_k>(q_regs, q_lds, warp_id, lane_id);

    // ======================================================== Mainloop ======================================================================
    // Number of trailing KV tiles that need masking (runtime value in this routine).
    int n_masking_steps = 1;
    if constexpr (Is_causal) {
        const int causal_start_col = m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q;
        const int first_mask_block = max(n_block_min, causal_start_col / kBlockN);
        n_masking_steps = n_block_max - first_mask_block;
    } else if constexpr (Is_local) {
        n_masking_steps = min(n_block_max - n_block_min, ceil_div(kBlockM, kBlockN));
    }
    n_masking_steps = min(max(n_masking_steps, 1), n_block_max - n_block_min);
    constexpr bool Assume_valid_rows = !Is_local;

    // ======================================================== Prefetch the first K tile ======================================================================
    // k_ptr points at row n_block_min * kBlockN, so the bound is the number of rows remaining from there.
    if (n_block_max > n_masking_steps) {
        fp8_prefetch_k_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, binfo.actual_seqlen_k - n_block_min * kBlockN);
    }

    // ======================================================== Main loop: no causal mask, prefetches the next K ============================================================
    for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
        int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
        // QK gemm (K was prefetched into LDS by the previous iteration / the prologue)
        vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
        fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
        // Prefetch V
        fp8_prefetch_v_to_lds<false/*Is_even_MN*/, kBlockN, kHeadDim, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
        if constexpr (Is_local) {
            fp8_apply_local_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(
                s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
                m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN,
                params.window_size_left, params.window_size_right, lane_id);
        }
        // Softmax + V from LDS into registers
        union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32];
        fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDim, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
        // cvt
        union_vec32_fp8 p_reg[WARP_M / 16];
        fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
        // PV MMAC + prefetch of the next K tile (paged KV)
        const int next_n_block_loop = n_block_loop + 1;
        const int block_table_idx_cur = n_block_loop * kBlockN / page_block_size;
        const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * page_block_size;
        const int block_table_idx_next = next_n_block_loop * kBlockN / page_block_size;
        const int block_table_offset_next = next_n_block_loop * kBlockN - block_table_idx_next * page_block_size;
        const int64_t table_delta = int64_t(block_table[block_table_idx_next] - block_table[block_table_idx_cur]);
        const int64_t offset_delta = int64_t(block_table_offset_next - block_table_offset_cur);
        Element_k* k_ptr_next = k_ptr
            + table_delta * int64_t(params.k_batch_stride)
            + offset_delta * int64_t(params.k_row_stride);
        const int max_seq_kv_offset_next = binfo.actual_seqlen_k - next_n_block_loop * kBlockN;
        fp8_pv_gemm_and_prefetch_k_paged<false/*Is_even_MN*/, kHeadDim, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr_next, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset_next);
        // Advance the K/V pointers
        k_ptr = k_ptr_next;
        v_ptr += table_delta * int64_t(params.v_batch_stride)
            + offset_delta * int64_t(params.v_row_stride);
    }

    // ======================================================== Masking loop ============================================================
    int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
        int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
        // If nothing prefetched K before this loop (n_block_max <= n_masking_steps), load it here.
        if (masking_step == 0 && n_block_max <= n_masking_steps) {
            fp8_prefetch_k_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset);
        }
        // QK gemm
        vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
        fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
        // Prefetch V
        fp8_prefetch_v_to_lds<false/*Is_even_MN*/, kBlockN, kHeadDim, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
        // Matches the fp16 forward path: the non-causal tail must mask the out-of-range
        // columns of the last, partial KV tile.
        if constexpr (!Is_causal && !Is_local) {
            fp8_apply_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, max_seq_kv_offset, 0, lane_id);
        }
        // Causal / local mask
        if constexpr (Is_causal) {
            fp8_apply_causal_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN, lane_id);
        } else if constexpr (Is_local) {
            fp8_apply_local_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(
                s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
                m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN,
                params.window_size_left, params.window_size_right, lane_id);
        }
        // Softmax + V from LDS into registers
        union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32];
        fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDim, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
        // cvt
        union_vec32_fp8 p_reg[WARP_M / 16];
        fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);

        if (masking_step + 1 < n_masking_steps) {
            // Another masked tile follows: resolve its page exactly like the main loop and
            // prefetch its K during the PV gemm so the next QK gemm reads fresh data.
            // next_n_block_loop <= n_block_max - 1 holds here by construction.
            const int next_n_block_loop = n_block_loop + 1;
            const int block_table_idx_cur = n_block_loop * kBlockN / page_block_size;
            const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * page_block_size;
            const int block_table_idx_next = next_n_block_loop * kBlockN / page_block_size;
            const int block_table_offset_next = next_n_block_loop * kBlockN - block_table_idx_next * page_block_size;
            const int64_t table_delta = int64_t(block_table[block_table_idx_next] - block_table[block_table_idx_cur]);
            const int64_t offset_delta = int64_t(block_table_offset_next - block_table_offset_cur);
            Element_k* k_ptr_next = k_ptr
                + table_delta * int64_t(params.k_batch_stride)
                + offset_delta * int64_t(params.k_row_stride);
            const int max_seq_kv_offset_next = binfo.actual_seqlen_k - next_n_block_loop * kBlockN;
            fp8_pv_gemm_and_prefetch_k_paged<false/*Is_even_MN*/, kHeadDim, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr_next, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset_next);
            k_ptr = k_ptr_next;
            v_ptr += table_delta * int64_t(params.v_batch_stride)
                + offset_delta * int64_t(params.v_row_stride);
        } else {
            // Last tile: PV MMAC only, nothing left to prefetch and no pointer to advance.
            fp8_pv_gemm_and_prefetch_k<false/*PrefetchK*/, false/*Is_even_MN*/, kHeadDim, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset);
        }
    }

    // ========================================== rescale by scores_sum ==========================================
    ElementAccum lse[WARP_M / 16];
    if constexpr (Return_softmax) {
        fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDim, WARP_M, WARP_N, true/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
    } else {
        fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDim, WARP_M, WARP_N, false/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
    }
    // ========================================== Store LSE (varlen) ==========================================
    if constexpr (Return_softmax) {
        fp8_epilogue_store_lse<false/*Is_even_MN*/, WARP_M, ElementAccum>(
            softmax_lse_ptr, scores_max, scores_sum, lse, row_offset_lse, binfo.actual_seqlen_q, m_block * kBlockM + warp_id * WARP_M, lane_id);
    }
    // ========================================== Store output =============================================
    fp8_epilogue_store_output<false/*Is_even_MN*/, kBlockM, kHeadDim, WARP_M, WARP_N, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, params.o_row_stride, binfo.actual_seqlen_q);
}
// Kernel entry for FP8 prefix prefill (paged KV cache + varlen) on gfx938 / gfx946.
// Launched with at most 256 threads per block (see __launch_bounds__); the wave index
// is threadIdx.x / 64.
//
// Grid-to-work mapping ("LPT" scheduling is enabled for causal attention):
//   causal:     blockIdx.x = head, blockIdx.y = batch, blockIdx.z reversed = row block
//   non-causal: blockIdx.x = row block, blockIdx.y = head, blockIdx.z = batch
// On other targets the body is compiled out and the kernel does nothing.
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout>
__global__ void __launch_bounds__(256, 1) flash_fp8_fwd_prefix_prefill_kernel_gfx938(Flash_fwd_params params) {
#if defined(__gfx938__) || defined(__gfx946__)
    constexpr bool Do_lpt = Is_causal;

    // Wave-uniform wave index.
    int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / 64);

    int head_idx;
    int batch_idx;
    int row_block;
    if constexpr (Do_lpt) {
        head_idx = blockIdx.x;
        batch_idx = blockIdx.y;
        row_block = gridDim.z - 1 - blockIdx.z;  // reversed order over row blocks
    } else {
        head_idx = blockIdx.y;
        batch_idx = blockIdx.z;
        row_block = blockIdx.x;
    }
    const int bidh = head_idx;
    const int bidb = batch_idx;
    int m_block = row_block;

    flash::compute_fp8_attn_prefix_prefill_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_K, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(
        params, bidb, bidh, m_block, warp_id);
#endif
}
} // namespace flash
File mode changed from 100644 to 100755
......@@ -176,23 +176,15 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
}
__builtin_amdgcn_sched_barrier(0);
} else { // 非 kHeaddim 128, 交给编译器后续的优化了
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (kHeadDimV/kBlockK) * ((WARP_M/32)*(kBlockK/32)); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[1])
:);
#endif
acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(pk_zero);
}
}
}
......@@ -418,4 +410,445 @@ inline __device__ void compute_attn_splitkv_int8(const Params &params) {
flash::compute_attn_mha_1rowblock_splitkv_int8<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_K, Return_softmax, Has_alibi, Split, M_MMAC_COUNT, REUSE_KV_TIMES, Flash_fwd_params>(params, bidb, bidh, warp_id);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MLS-based FP8 Paged Attention, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sums the per-warp output accumulators `acc_o` across the WARP_NUM warps of a
// block, going through the LDS scratch buffer `acc_o_lds`.
//
// For every (h_idx, k_idx) tile group the function:
//   1. lets each warp store its fragments into its own slab of `acc_o_lds`,
//   2. lets warp 0 add slabs 1..WARP_NUM-1 onto slab 0 (only when WARP_NUM > 1),
//   3. lets every warp read the summed values from slab 0 back into `acc_o`.
// A block-wide barrier separates the steps.
//
// Only lanes whose (lane_id & 15) is below the query-row count rounded up to an
// even number take part; that count is REUSE_KV_TIMES when it is positive and
// `seqlen_q` otherwise.
//
// acc_o     : in/out register accumulators.
// acc_o_lds : LDS scratch, indexed up to WARP_NUM * M_WARP_COUNT * M_MMAC_COUNT * 16 * 32 elements.
// seqlen_q  : run-time query row count (used only when REUSE_KV_TIMES <= 0).
// warp_id   : warp index inside the block.
// lane_id   : lane index inside the warp.
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_acco_reduce_compact_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum* acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id) {
    constexpr int kCols = 32;                                    // columns per staged row
    constexpr int kRowsPerWarp = M_WARP_COUNT * M_MMAC_COUNT * 16;
    constexpr int kWarpSlab = kRowsPerWarp * kCols;              // elements owned by one warp

    const int lane_row = lane_id & 15;
    const int lane_col = (lane_id >> 4) * 4;

    // Query rows taking part in the reduction, rounded up to an even count.
    const int q_rows = (REUSE_KV_TIMES > 0) ? REUSE_KV_TIMES : seqlen_q;
    const int q_rows_even = ((q_rows + 1) / 2) * 2;
    const bool lane_active = lane_row < q_rows_even;

    #pragma unroll
    for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
        #pragma unroll
        for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
            const int tile_base = h_idx * M_WARP_COUNT * K_WARP_COUNT + k_idx * M_WARP_COUNT;

            // Step 1: every warp stages its fragments into its own slab.
            if (lane_active) {
                #pragma unroll
                for (int wm = 0; wm < M_WARP_COUNT; ++wm) {
                    #pragma unroll
                    for (int mm = 0; mm < M_MMAC_COUNT; ++mm) {
                        const int row = (wm * M_MMAC_COUNT + mm) * 16 + lane_row;
                        ElementAccum* slab_row = acc_o_lds + (warp_id * kRowsPerWarp + row) * kCols + lane_col;
                        #pragma unroll
                        for (int tn = 0; tn < 2; ++tn) {
                            *(vec4_fp32*)(slab_row + tn * 16) = acc_o[tile_base + wm][tn * 2 + mm].f32;
                        }
                    }
                }
            }
            __syncthreads();

            // Step 2: warp 0 folds the other warps' slabs onto slab 0.
            if constexpr (WARP_NUM > 1) {
                if (warp_id == 0 && lane_active) {
                    #pragma unroll
                    for (int wm = 0; wm < M_WARP_COUNT; ++wm) {
                        #pragma unroll
                        for (int mm = 0; mm < M_MMAC_COUNT; ++mm) {
                            #pragma unroll
                            for (int tn = 0; tn < 2; ++tn) {
                                #pragma unroll
                                for (int v = 0; v < 4; ++v) {
                                    const int row = (wm * M_MMAC_COUNT + mm) * 16 + lane_row;
                                    ElementAccum* cell = acc_o_lds + row * kCols + tn * 16 + lane_col + v;
                                    ElementAccum total = cell[0];
                                    #pragma unroll
                                    for (int w = 1; w < WARP_NUM; ++w) {
                                        total += cell[w * kWarpSlab];
                                    }
                                    cell[0] = total;
                                }
                            }
                        }
                    }
                }
            }
            __syncthreads();

            // Step 3: every warp reloads the summed tile from slab 0.
            if (lane_active) {
                #pragma unroll
                for (int wm = 0; wm < M_WARP_COUNT; ++wm) {
                    #pragma unroll
                    for (int mm = 0; mm < M_MMAC_COUNT; ++mm) {
                        const int row = (wm * M_MMAC_COUNT + mm) * 16 + lane_row;
                        ElementAccum* slab0_row = acc_o_lds + row * kCols + lane_col;
                        #pragma unroll
                        for (int tn = 0; tn < 2; ++tn) {
                            acc_o[tile_base + wm][tn * 2 + mm].f32 = *(vec4_fp32*)(slab0_row + tn * 16);
                        }
                    }
                }
            }
            // Slab 0 is overwritten by the next (h_idx, k_idx) iteration.
            __syncthreads();
        }
    }
}
// Applies a sliding-window (local) causal mask to a warp's score fragments by
// overwriting every out-of-window element with -INFINITY.
//
// For a query row r (logical row = r / ngroups, logical_q = max_seqlen_q / ngroups)
// the columns kept are
//     [ max(0, row + max_seqlen_k - logical_q - window_size_left),
//       min(max_seqlen_k - 1, row + max_seqlen_k - logical_q + window_size_right) ]
// (both ends inclusive).
//
// tensor           : in/out score fragments, indexed [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec].
// col_idx_offset_  : key column index of this warp's first column.
// max_seqlen_k     : number of valid key columns.
// row_idx_offset_  : query row index of this block's first row.
// max_seqlen_q     : number of query rows (before division by ngroups).
// ngroups          : divisor turning a physical row index into a logical one.
// window_size_left / window_size_right : window extents around the diagonal.
//
// Fix relative to the previous version: the right limit used to be clamped to
// max_seqlen_k while being compared with '>', which left column max_seqlen_k
// (one past the last valid key) unmasked whenever the clamp was active. The
// inclusive limit is now clamped to max_seqlen_k - 1.
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_kvcache_apply_mask_local_causal_gfx938(
    DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
    const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q,
    const int ngroups, const int window_size_left, const int window_size_right) {
    const int lane_id = threadIdx.x & 63;
    // Per-lane position inside the fragment: low 4 lane bits pick the row,
    // the remaining bits pick a group of 8 columns.
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 8;
    // Loop invariant: logical query length.
    const int logical_q = max_seqlen_q / ngroups;
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int logical_row = row_idx / ngroups;
            // Column on the diagonal for this row.
            const int diag_col = logical_row + max_seqlen_k - logical_q;
            // Inclusive window bounds; the right bound never exceeds the last valid key.
            const int col_idx_limit_left = max(0, diag_col - window_size_left);
            const int col_idx_limit_right = min(max_seqlen_k - 1, diag_col + window_size_right);
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 4;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx;
                        if (col_idx < col_idx_limit_left || col_idx > col_idx_limit_right) {
                            tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Loads the Q tile of the current block into registers (`q_reg`), staging it
// through LDS (`q_lds`).
//
// Phase 1: for each 16-row tile `min_tile_m` and each 128-element slice of the
//          head dimension, the warp whose `warp_id` equals `min_tile_m` issues
//          one inline_matrix_load_128x16_b8_lds_trans into LDS.
// Phase 2: after wait_buffer_data_arrived, every warp reads all tiles back from
//          LDS with two DS_READ_MATRIX_64x16_B8 per 128-element slice.
//
// q_addr              : buffer descriptor words for Q (words 0/1 are combined as a 64-bit base below).
// q_lds               : LDS staging area; 16 * 128 elements are used per (min_tile_m, k_loop).
// q_reg               : output fragments, [M_MMAC_COUNT][kHeadDim / 64].
// warp_id             : warp index inside the block.
// query_seqlen_stride : Q row stride in elements (multiplied by sizeof(Element) for byte offsets).
// max_seq_q_offset    : number of valid Q rows remaining from this block's first row.
//
// kBlockM is not referenced in the body. Supported shapes are enforced by the
// static_asserts: kHeadDim in {128, 256} and WARP_M == 32.
template<int kHeadDim, int kBlockM, int WARP_M, int M_MMAC_COUNT, typename Element>
__forceinline__ __device__ void fp8_mha_prefetch_q_to_vgpr_gfx938(
vec4_uint q_addr,
Element* q_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset) {
static_assert(kHeadDim == 128 || kHeadDim == 256);
static_assert(WARP_M == 32);
// Local descriptor: words 0/1 are rewritten per tile below, word 2 carries the
// row stride and word 3 the per-tile row filter.
vec4_uint q_srsrc;
q_srsrc[1] = q_addr[1];
q_srsrc[2] = query_seqlen_stride;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int k_loop = 0; k_loop < kHeadDim / 128; ++k_loop) {
// Exactly one warp issues the load for a given 16-row tile.
if (warp_id == min_tile_m) {
const int q_row_base = min_tile_m * 16;
// Rows of this tile that lie inside the sequence (may be <= 0 or > 16).
const int valid_rows = max_seq_q_offset - q_row_base;
// A tile with no valid row falls back to row 0 as its base address.
const int safe_q_row_base = valid_rows <= 0 ? 0 : q_row_base;
// Number of rows to filter out; presumably inline_min_max<0, 16> clamps to [0, 16] - TODO confirm.
const int nm_filter = inline_min_max<0, 16>(16 - valid_rows);
// Word 3 is 0 for a full tile, otherwise the filter count shifted left by 8.
q_srsrc[3] = valid_rows >= 16 ? 0 : (nm_filter << 8);
const int64_t row_offset_bytes = int64_t(safe_q_row_base) * int64_t(query_seqlen_stride) * sizeof(Element);
const int64_t dim_offset_bytes = int64_t(k_loop) * 128 * sizeof(Element);
// Overwrites descriptor words 0 and 1 with the tile's base address.
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + row_offset_bytes + dim_offset_bytes);
// Each (tile, slice) pair owns a 16 x 128 element region of LDS.
const int lds_offset_bytes = (min_tile_m * (kHeadDim / 128) + k_loop) * 16 * 128 * sizeof(Element);
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
}
}
}
// Wait for the loads issued above; the <true> argument is labelled "sync" at the call site.
flash::wait_buffer_data_arrived<true/*sync*/>(0);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int k_loop = 0; k_loop < kHeadDim / 128; ++k_loop) {
const int lds_offset_bytes = (min_tile_m * (kHeadDim / 128) + k_loop) * 16 * 128 * sizeof(Element);
// LDS address truncated to int for the read macro.
const int q_lds_load_offset = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
// Two reads per 128-element slice; the second starts 1024 bytes after the first.
DS_READ_MATRIX_64x16_B8(q_lds_load_offset, q_reg[min_tile_m][k_loop * 2 + 0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[min_tile_m][k_loop * 2 + 1].i32x4, true/*transpose*/)
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
}
// Scheduling barrier: keeps the compiler from moving instructions across this point.
__builtin_amdgcn_sched_barrier(0);
}
// Computes one block of FP8 paged attention (optionally one KV split of it).
//
// Outline:
//   1. Resolve sequence lengths, the KV split handled by this block and the
//      (m_block, head-dim split) pair encoded in blockIdx.x.
//   2. Carve the dynamic shared memory, compute the range of KV blocks to visit
//      and the Q/K/V base addresses (K/V go through the page block table).
//   3. Load Q into registers, then loop over KV blocks: QK gemm, descale,
//      masking, online softmax, P conversion to FP8 and PV gemm.
//   4. Reduce the accumulators across warps, rescale, and store the output and
//      the softmax statistics.
//
// Template parameters used here: Kernel_traits (tile sizes and accumulator
// types), Is_causal / Is_local (mask selection), Is_Varlen, Split,
// M_MMAC_COUNT, REUSE_KV_TIMES and HEADDIM_V_SPLIT. Partition_Size is not
// referenced in the body.
//
// params  : kernel parameter struct.
// bidb    : batch index.
// bidh    : query head index.
// warp_id : warp index inside the block.
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_fp8_gfx938(const Params &params, const int bidb, const int bidh, const int warp_id) {
using Element = fp8_e4m3;
using ElementAccum = typename Kernel_traits::ElementAccum;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
// Each warp owns WARP_N of the kBlockN key columns of a KV block.
constexpr int WARP_NUM = kBlockN / WARP_N;
// Width of the V head-dim slice handled by one block.
constexpr int kHeadDimVSplit = kHeadDimV / HEADDIM_V_SPLIT;
static_assert(kBlockK == 64);
static_assert(kHeadDim == 128 || kHeadDim == 256);
static_assert(kHeadDimVSplit == 128);
flash::SafeDecodeBlockInfo binfo;
binfo.set_params<Params, /*Is_Q_varlen=*/Is_Varlen, /*Is_K_Cumulative=*/false>(params, bidb);
int split_id = 0;
// Keep the full key length: the masks below need it even when this block only sees one split.
int original_actual_seqlen_k = binfo.actual_seqlen_k;
int partition_size = 0;
if constexpr (Split) {
// With Split, blockIdx.y enumerates the KV splits.
split_id = blockIdx.y;
if constexpr (Is_Varlen) {
partition_size = splitkv_get_partitionsize_of_fix_numsplits(binfo.actual_seqlen_k, params.num_splits);
// Keys left for this split, capped at one partition (may become <= 0; handled by the early return below).
binfo.actual_seqlen_k = min(binfo.actual_seqlen_k - split_id * partition_size, partition_size);
} else {
partition_size = params.partition_size;
// At least one split; the last split also takes the tail beyond a whole number of partitions.
int num_splits = max(1, floor_div(binfo.actual_seqlen_k, partition_size));
binfo.actual_seqlen_k = (split_id == num_splits - 1)
? binfo.actual_seqlen_k - split_id * partition_size : partition_size;
binfo.actual_seqlen_k = (split_id >= num_splits) ? 0 : binfo.actual_seqlen_k;
// Blocks launched for splits this sequence does not have exit immediately.
if (split_id >= num_splits) return;
}
}
// blockIdx.x encodes m_block * HEADDIM_V_SPLIT + headdim_split_id
// (the mask form assumes HEADDIM_V_SPLIT is a power of two).
int block_x = blockIdx.x;
const int m_block = block_x / HEADDIM_V_SPLIT;
const int headdim_split_id = block_x & (HEADDIM_V_SPLIT - 1);
int ngroups = 1;
int actual_seqlen_q = binfo.actual_seqlen_q;
if constexpr (Is_Varlen) {
// In the varlen path the row count is scaled by ngroups.
ngroups = params.ngroups;
actual_seqlen_q = binfo.actual_seqlen_q * ngroups;
}
// Nothing to do when the m-block lies past the last row or no keys are assigned.
if (m_block * kBlockM >= actual_seqlen_q || binfo.actual_seqlen_k <= 0) return;
extern __shared__ Element fp8_smem[];
constexpr int q_smem_bytes = STAGES * kBlockM * kBlockK * sizeof(Element);
constexpr int kv_smem_bytes = STAGES * kBlockK * WARP_N * sizeof(Element) * WARP_NUM;
constexpr int gemm_smem_bytes = q_smem_bytes > kv_smem_bytes ? q_smem_bytes : kv_smem_bytes;
// q_lds, k_lds, v_lds and acc_o_lds all alias the start of the dynamic shared
// memory; max_lds starts gemm_smem_bytes after it.
Element* q_lds = reinterpret_cast<Element*>(fp8_smem);
Element* k_lds = reinterpret_cast<Element*>(fp8_smem);
Element* v_lds = k_lds;
ElementAccum* acc_o_lds = reinterpret_cast<ElementAccum*>(fp8_smem);
ElementAccum* max_lds = reinterpret_cast<ElementAccum*>(
reinterpret_cast<char*>(fp8_smem) + gemm_smem_bytes);
const int query_seqlen_stride = params.q_row_stride;
const int kcache_seqlen_stride = params.k_row_stride;
const int vcache_seqlen_stride = params.v_row_stride;
// Range [n_block_min, n_block_max) of kBlockN-wide KV blocks to visit.
int n_block_min = 0;
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_local) {
// Restrict the range to the blocks that intersect the sliding window of the
// rows in this m-block, expressed relative to the start of this split.
const int q_row_start = m_block * kBlockM;
const int q_row_end = min(actual_seqlen_q, (m_block + 1) * kBlockM) - 1;
const int logical_q = Is_Varlen ? actual_seqlen_q / ngroups : actual_seqlen_q;
const int logical_row_start = Is_Varlen ? q_row_start / ngroups : q_row_start;
const int logical_row_end = Is_Varlen ? q_row_end / ngroups : q_row_end;
const int split_seqlen_start = Split ? split_id * partition_size : 0;
// local_left is inclusive, local_right is exclusive (note the "+ 1").
const int local_left = max(0, logical_row_start + original_actual_seqlen_k - logical_q - params.window_size_left);
const int local_right = min(original_actual_seqlen_k, logical_row_end + original_actual_seqlen_k - logical_q + params.window_size_right + 1);
const int split_local_left = local_left - split_seqlen_start;
const int split_local_right = local_right - split_seqlen_start;
const int n_block_count = n_block_max;
const int raw_n_block_min = max(0, split_local_left / kBlockN);
const int raw_n_block_max = ceil_div(max(0, split_local_right), kBlockN);
// Clamp so that at least one block inside [0, n_block_count) is always visited.
n_block_min = min(max(raw_n_block_min, 0), max(0, n_block_count - 1));
n_block_max = min(max(raw_n_block_max, n_block_min + 1), n_block_count);
}
const int page_block_size = params.page_block_size;
// Block table row of this batch entry, advanced to the first page of this split.
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
const int this_split_seqlen_start = Split ? split_id * partition_size : 0;
block_table = block_table + (Split ? ceil_div(this_split_seqlen_start, page_block_size) : 0);
// Page index and in-page row offset of the first KV block to visit.
const int block_table_idx = n_block_min * kBlockN / page_block_size;
const int block_table_offset = n_block_min * kBlockN - block_table_idx * page_block_size;
// K/V element offsets: physical page * batch stride + row in page + KV head (bidh / h_h_k_ratio).
const int64_t row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride)
+ block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const int64_t row_offset_v = int64_t(block_table[block_table_idx]) * int64_t(params.v_batch_stride)
+ block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
// Q element offset: varlen indexes by cumulative row count, otherwise by batch stride.
const int64_t row_offset_q = Is_Varlen
? binfo.sum_s_q * ngroups * int64_t(query_seqlen_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(query_seqlen_stride)
: bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(query_seqlen_stride);
auto q_addr = prepare_for_buffer_load<kHeadDim, Element, false>(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q);
auto k_addr = prepare_for_buffer_load<kHeadDim, Element, false>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
// V additionally starts at this block's head-dim slice.
auto v_addr = prepare_for_buffer_load<kHeadDimV, Element, false>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v + headdim_split_id * kHeadDimVSplit);
// Descale factors: Q is indexed by query head, K/V by KV head.
const int q_descale_head = bidh;
const int kv_descale_head = bidh / params.h_h_k_ratio;
const ElementAccum q_descale = params.q_descale_ptr[bidb * params.q_descale_batch_stride + q_descale_head * params.q_descale_head_stride];
const ElementAccum k_descale = params.k_descale_ptr[bidb * params.k_descale_batch_stride + kv_descale_head * params.k_descale_head_stride];
const ElementAccum v_descale = params.v_descale_ptr[bidb * params.v_descale_batch_stride + kv_descale_head * params.v_descale_head_stride];
// Combined Q*K descale, duplicated into both halves of a 2-wide vector.
__float2 qk_descale = {q_descale * k_descale, q_descale * k_descale};
// Destination pointers for the softmax statistics. Which ones are set depends
// on Split / Is_Varlen; the others stay nullptr.
int row_offset_lse;
ElementAccum *scores_sum_ptr = nullptr;
ElementAccum *scores_max_ptr = nullptr;
ElementAccum *softmax_lse_ptr = nullptr;
if constexpr (Split) {
int row_offset_scores_split;
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.h * ngroups * params.total_q);
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse + row_offset_scores_split;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.b * params.h * params.seqlen_q);
scores_sum_ptr = reinterpret_cast<ElementAccum*>(params.scores_sum_ptr) + row_offset_lse + row_offset_scores_split;
scores_max_ptr = reinterpret_cast<ElementAccum*>(params.scores_max_ptr) + row_offset_lse + row_offset_scores_split;
}
} else {
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
}
}
// Fragment counts in units of 32x32 tiles.
constexpr int M_WARP_COUNT = WARP_M / 32;
constexpr int K_WARP_COUNT = kBlockK / 32;
constexpr int N_WARP_COUNT = WARP_N / 32;
constexpr int K_LOOP_COUNT = kHeadDimVSplit / kBlockK;
// Running softmax statistics and the output accumulator.
vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT];
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT];
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4];
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64];
attention_initialize<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
// Q is loaded once and stays in registers for the whole KV loop.
fp8_mha_prefetch_q_to_vgpr_gfx938<kHeadDim, kBlockM, WARP_M, M_MMAC_COUNT, Element>(
q_addr, q_lds, q_reg, warp_id, query_seqlen_stride, actual_seqlen_q - m_block * kBlockM);
int n_block_loop = n_block_min;
// With PrefetchK, K of the first block is requested here and K of block i + 1
// is requested during the PV gemm of block i.
constexpr bool PrefetchK = true;
if constexpr (PrefetchK) {
int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_min * kBlockN;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
}
for (; n_block_loop < n_block_max; ++n_block_loop) {
// First key column of this warp inside the split, and keys remaining from the start of this KV block.
const int warp_offset_in_seqkv = n_block_loop * kBlockN + warp_id * WARP_N;
const int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// For this shape the V prefetch is handed to the QK gemm routine instead of being issued separately.
constexpr bool PrefetchVInQK = (kHeadDim == 128 && K_LOOP_COUNT == 2);
if constexpr (!PrefetchK) {
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
}
// S = Q * K^T for this warp's key columns.
vec4_Accum<ElementAccum> s_reg[M_WARP_COUNT * N_WARP_COUNT][4];
fp8_kvcache_qk_gemm_gfx938<PrefetchVInQK, K_LOOP_COUNT, kHeadDim, kBlockK, WARP_M, WARP_N, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
k_addr, v_addr, k_lds, v_lds, q_reg, s_reg, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit);
if constexpr (!PrefetchVInQK) {
fp8_kvcache_prefetch_v_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element>(
v_addr, v_lds, warp_id, vcache_seqlen_stride, warp_seqkv_limit);
}
// Undo the FP8 quantization scale of Q and K on the scores.
fp8_kvcache_apply_descale_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(s_reg, qk_descale);
// Mask selection. The causal variants receive the column index relative to the
// full sequence (split start added back) together with the unsplit key length.
if constexpr (Is_causal) {
if constexpr (Is_Varlen) {
if constexpr (Is_local) {
fp8_kvcache_apply_mask_local_causal_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, ngroups, params.window_size_left, params.window_size_right);
} else {
kvcache_apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, ngroups);
}
} else {
// NOTE(review): the non-varlen causal path uses the mtp mask even when Is_local is set - confirm this is intended.
kvcache_apply_mask_causal_gfx938_mtp<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, params.mtp, params.layout);
}
} else {
// Non-causal: only the tail beyond the remaining key count is masked.
kvcache_apply_mask_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(s_reg, warp_seqkv_limit, warp_id * WARP_N);
}
// Softmax update for this block; updates scores_max / scores_sum and acc_o, using max_lds as scratch.
mla_softmax_rescale_o<Is_causal || Is_local, ElementAccum, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, N_WARP_COUNT, WARP_NUM, M_MMAC_COUNT>(
s_reg, scores_max, scores_sum, acc_o, max_lds, warp_id, params.scale_softmax_log2);
// Convert the probabilities to FP8 for the PV gemm.
union_vec32_fp8 p_reg[M_MMAC_COUNT];
fp8_kvcache_cvt_f32_to_fp8_gfx938<M_MMAC_COUNT, Element, ElementAccum>(p_reg, s_reg);
// Address delta from the current KV block to the next one (the last block
// points at itself): difference in physical page plus difference in in-page row.
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
const int64_t k_addr_offset = (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
// O += P * V; the routine also receives k_addr and k_addr_offset for the next block's K.
fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938<PrefetchK, K_LOOP_COUNT, kBlockK, kBlockN, M_WARP_COUNT, K_WARP_COUNT, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
v_addr, k_addr, v_lds, k_lds, p_reg, acc_o, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit, k_addr_offset);
// Advance the V base address (first 64 bits of the descriptor) by the same block delta, in bytes.
*(int64_t*)&v_addr += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * int64_t(params.v_row_stride)) * sizeof(Element);
}
// Wait for the K prefetch issued by the last iteration and for LDS traffic
// before the shared memory is reused as acc_o_lds.
if constexpr (PrefetchK) {
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
flash::wait_lds_data_arrived<true/*sync*/>(0);
const int thread_id = threadIdx.x;
const int lane_id = thread_id & 63;
if constexpr (WARP_NUM > 1) {
// Sum the per-warp partial outputs.
// NOTE(review): params.seqlen_q is passed here while the rest of the function uses actual_seqlen_q - confirm.
fp8_kvcache_acco_reduce_compact_gfx938<REUSE_KV_TIMES, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, ElementAccum>(
acc_o, acc_o_lds, params.seqlen_q, warp_id, lane_id);
}
// Epilogue rescale of acc_o using scores_sum and the V descale factor.
fp8_kvcache_epilogue_rescale_acco_gfx938<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(acc_o, scores_sum, v_descale);
if constexpr (Is_Varlen) {
// Varlen: store the LSE and the output at offsets derived from the cumulative row count.
kvcache_epilogue_store_softmax_lse<Is_Varlen, true, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, softmax_lse_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM, params.total_q, params.ngroups);
const int64_t row_offset_o = binfo.sum_s_q * ngroups * int64_t(params.o_row_stride) + bidh * ngroups * params.o_head_stride + headdim_split_id * kHeadDimVSplit + m_block * kBlockM * int64_t(params.o_row_stride);
kvcache_varlen_epilogue_store_output_gfx938<Params, kHeadDimV, kHeadDimVSplit, Split, SplitkvAccumType, ElementAccum, kBlockM, kBlockK, WARP_NUM, K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT>(
acc_o, params, row_offset_o, actual_seqlen_q - m_block * kBlockM, bidb, bidh, m_block, split_id, headdim_split_id, warp_id, lane_id);
} else {
// Fixed length: hand the max / sum pointers and the output to the store helpers.
kvcache_epilogue_store_max_sum<Split, true/*Is_16x32*/, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, scores_max_ptr, scores_sum_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM);
kvcache_epilogue_store_output_gfx938<Params, kHeadDimV, kHeadDimVSplit, true/*alt*/, Split, SplitkvAccumType, ElementAccum, kBlockM, kBlockK, WARP_NUM, K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT>(
acc_o, params, bidb, bidh, m_block, split_id, headdim_split_id, warp_id, lane_id);
}
}
// Device-side entry for the FP8 split-KV paged-attention kernel: derives the
// head / batch indices and the warp index from the launch coordinates and
// forwards to compute_attn_1rowblock_splitkv_fp8_gfx938. Partition_Size is
// forwarded multiplied by 128.
//
// The body is compiled only when __gfx938__ is defined; for every other target
// this function is empty.
// NOTE(review): the prefix-prefill kernel in this change also accepts
// __gfx946__ in its guard - confirm whether this one should as well.
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_fp8_gfx938(const Params &params) {
#if defined(__gfx938__)
    int bidh;  // block index for the head
    int bidb;  // block index for the batch
    if constexpr (Split) {
        // blockIdx.z enumerates batch x num_head with num_head varying fastest.
        bidh = blockIdx.z % params.h;
        bidb = blockIdx.z / params.h;
    } else {
        bidh = blockIdx.y;
        bidb = blockIdx.z;
    }
    // Warp index inside the block (64 lanes per warp), made wave-uniform.
    const int lane_warp_idx = threadIdx.x / 64;
    const int uniform_warp_idx = __builtin_amdgcn_readfirstlane(lane_warp_idx);
    flash::compute_attn_1rowblock_splitkv_fp8_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(
        params, bidb, bidh, uniform_warp_idx);
#endif
}
} // namespace flash
File mode changed from 100644 to 100755
......@@ -486,3 +486,38 @@ void run_int8_flash_fwd_prefix_prefill(Flash_fwd_params &params, hipStream_t str
run_int8_flash_fwd_prefix_prefill_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T, Float16, int8_t>, Is_causal>(params, stream);
});
}
// Launches flash_fp8_fwd_prefix_prefill_kernel_gfx938 on `stream`.
//
// Grid layout:
//   causal     : (num_heads, batch, num_m_blocks)
//   non-causal : (num_m_blocks, num_heads, batch)
// Block size is Kernel_traits::kNThreads and 16 KiB of dynamic shared memory is
// requested. The Return_softmax flag is taken from params.softmax_lse_ptr, the
// Is_local flag from the window sizes (never set together with Is_causal).
template<typename Kernel_traits, bool Is_causal>
void run_flash_fp8_fwd_prefix_prefill_launcher_gfx938(Flash_fwd_params &params, hipStream_t stream) {
    constexpr bool Has_Alibi = false;
    size_t smem_size = 16 * 1024;

    // Number of kBlockM-row tiles along seqlen_q (rounded up).
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;

    dim3 grid;
    if (Is_causal) {
        grid = dim3(params.h, params.b, num_m_block);
    } else {
        grid = dim3(num_m_block, params.h, params.b);
    }

    // A sliding window is honoured only on the non-causal path.
    const bool has_window = params.window_size_left > 0 && params.window_size_right >= 0;
    const bool is_local = has_window && !Is_causal;

    BOOL_SWITCH(params.softmax_lse_ptr != nullptr, kReturnLse, [&] {
        BOOL_SWITCH(is_local, kIsLocal, [&] {
            // `Layout` is the name made available by LAYOUT_SWITCH.
            LAYOUT_SWITCH(params.layout, [&]{
                auto kernel_fn = &flash_fp8_fwd_prefix_prefill_kernel_gfx938<Kernel_traits, true/*Is_training*/, false/*Is_dropout*/, Is_causal, kIsLocal, true/*Is_even_K*/, kReturnLse, Has_Alibi, Layout>;
                kernel_fn<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
            });
        });
    });
}
// Host entry point for the FP8 prefix-prefill forward pass.
// On architectures >= 938 it selects the causal / non-causal launcher from
// params.is_causal; on anything older it prints a message and launches nothing.
template<typename T, int Headdim, int HeaddimV>
void run_fp8_flash_fwd_prefix_prefill(Flash_fwd_params &params, hipStream_t stream) {
    if (getArch() < 938) {
        printf("\x1b[31mfp8 prefix_prefill is not supported in this arch!\033[0m\n");
        return;
    }
    using Fp8PrefillTraits = Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T, Float16, fp8_e4m3>;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        run_flash_fp8_fwd_prefix_prefill_launcher_gfx938<Fp8PrefillTraits, Is_causal>(params, stream);
    });
}
File mode changed from 100644 to 100755
......@@ -35,6 +35,12 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_gfx938_kernel(Params
}
// __global__ entry point for the FP8 split-KV paged-attention kernel.
// Forwards all template arguments and `params` unchanged to
// flash::compute_attn_splitkv_fp8_gfx938. Launch bounds: 256 threads per block.
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_fp8_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_fp8_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, const bool Tail, typename Params>
void run_splitkv_reduce(Params &params, hipStream_t stream) {
......@@ -245,7 +251,9 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params &params, hipStream_t strea
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.window_size_left > 0 and params.window_size_right >= 0) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, false, true/*Is_local*/, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, 0>;
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, true/*Is_local*/, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, 0>;
});
});
} else if (params.mtp == 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
......@@ -325,7 +333,7 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream)
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
constexpr int HEADDIM_V_SPLIT = Kernel_traits::kHeadDimV == 256 ? 2 : 1;
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.mtp == 1) {
......@@ -362,12 +370,106 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream)
}
// Host launcher for the FP8 split-KV paged-attention kernel.
//
// Steps:
//   1. Size the dynamic shared memory: the larger of the accumulator-reduction
//      buffer and (gemm staging + softmax-max scratch), with a 17 KiB floor.
//   2. Build the grid: x = num_m_block * HEADDIM_V_SPLIT; with more than one
//      split, y = num_splits and z = batch * heads, otherwise y = heads and
//      z = batch.
//   3. Pick the kernel instantiation from the run-time flags (varlen, sliding
//      window, mtp, seqlen_q, num_splits) and launch it with WARP_NUM * 64 threads.
//   4. Run the split-KV reduction (varlen or fixed-length variant).
//
// Changes relative to the previous version:
//   * the FA_DEBUG printf formats its size_t arguments with %zu instead of %ld;
//   * `kernel` starts as nullptr and is checked before the launch, so a switch
//     macro that selects nothing can no longer launch an indeterminate pointer.
template<typename Kernel_traits>
void run_fp8_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream) {
    constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
    constexpr int kReduceBlockK = 32;
    const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
    const size_t smem_for_acc = Kernel_traits::kBlockM * WARP_NUM * kReduceBlockK * sizeof(float);
    const size_t q_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockM * Kernel_traits::kBlockK * sizeof(Float8_e4m3_t);
    const size_t k_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockK * Kernel_traits::kWaveN * sizeof(Float8_e4m3_t) * WARP_NUM;
    const size_t v_smem_size = k_smem_size;
    const size_t smem_for_gemm = std::max(q_smem_size, std::max(k_smem_size, v_smem_size));
    // The reduction buffer aliases the gemm staging area, hence max() rather than a sum.
    const size_t required_smem_size = std::max(smem_for_acc, smem_for_gemm + smem_for_max);
    const size_t smem_size = size_t(std::max<size_t>(17 * 1024, required_smem_size));
    if (std::getenv("FA_DEBUG") != nullptr) {
        // %zu is the conversion specifier for size_t.
        printf("smem_for_max: %zu | smem_for_acc: %zu | q_smem: %zu k_smem: %zu v_smem: %zu | smem_for_gemm: %zu | needed required_smem_size: %zu | smem_size: %zu\n",
               smem_for_max, smem_for_acc, q_smem_size, k_smem_size, v_smem_size, smem_for_gemm, required_smem_size, smem_size);
    }
    // compute block partition along seqlen_q direction
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    // decide task dispatch logic
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
    // acquire kernel function
    void (*kernel)(Flash_fwd_params) = nullptr;
    constexpr int HEADDIM_V_SPLIT = Kernel_traits::kHeadDimV == 256 ? 2 : 1;
    grid.x = num_m_block * HEADDIM_V_SPLIT;
    BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
        if (params.window_size_left > 0 && params.window_size_right >= 0) {
            // Sliding-window (local) causal attention.
            M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                    constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
                    PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
                        kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, true/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
                    });
                });
            });
        } else if (params.mtp == 1) {
            // mtp == 1: non-causal instantiation.
            M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                    constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
                    REUSEKV_SWITCH(params.seqlen_q, [&] {
                        constexpr bool Is_local = false;
                        kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
                    });
                });
            });
        } else {
            // Default: causal instantiation without a window.
            M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                    constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
                    PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
                        kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
                    });
                });
            });
        }
    });
    // Refuse to launch when no instantiation was selected above.
    if (kernel == nullptr) {
        printf("\x1b[31mNo fp8 splitkv kernel was selected for this configuration!\033[0m\n");
        return;
    }
    // Kernel execution: one 64-lane warp per kWaveN-wide slice of kBlockN.
    const int nthread = WARP_NUM * 64;
    kernel<<<grid, nthread, smem_size, stream>>>(params);
    // reduce PA v2
    if (params.q_batch_stride == 0) {
        run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
    } else {
        run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
    }
}
template<typename T, int Headdim, int HeaddimV>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream) {
// decide whether commonly used headdims
const bool is_commonly_used = params.d % 64 == 0 and params.d_value % 64 == 0/*prefetch 2 32x32 blocks along headdim*/;
// For latest archs, mls can be applied for headdim 128
if ((getArch() >= 938) and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used) {
// For latest archs, MLS can be applied for the common decode head dims.
constexpr bool use_gfx938_mls =
(Headdim == 128 and HeaddimV == 128) or
(Headdim == 256 and HeaddimV == 256);
if constexpr (use_gfx938_mls) {
const bool is_local = params.window_size_left > 0 && params.window_size_right >= 0;
const bool use_mls_mask = params.is_e4m3 ? true : params.is_causal;
if ((getArch() >= 938) and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used and use_mls_mask) {
if (params.is_e4m3) {
// Decide whether to compile for all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
run_fp8_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 64, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_fp8_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 64, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
#endif
} else {
// Decide whether to compile for all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
......@@ -380,8 +482,11 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream)
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
#endif
}
return;
}
}
// For MHA-fma, headdim = 128
else if (params.seqlen_q == 1 and !params.seqlenq_ngroups_swapped and Headdim == 128 and HeaddimV == 128 and std::getenv("PA_USE_FMA") != nullptr) {
if (params.seqlen_q == 1 and !params.seqlenq_ngroups_swapped and Headdim == 128 and HeaddimV == 128 and std::getenv("PA_USE_FMA") != nullptr) {
constexpr int kBlockN = 128;
run_flash_splitkv_fwd_mha<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32/*kBlockM*/, kBlockN, 32/*kBlockK*/, 32, 32, 2/*STAGES*/, false, false, T, float> >(params, stream);
}
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment