Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
......@@ -428,8 +428,8 @@ __forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(vec4_
: ((warp_id & 1) ? warp_id - 1: warp_id + 1);
int lds_load_offset = n_loop * WARP_NUM * (64 * 4) + warp_id_symmetry * 64 * 4 + lane_id * 4;
vec4_Accum<ElementAccum> symmetry_data = *(vec4_Accum<ElementAccum>*)(s_reg_lds + lds_load_offset);
s_reg[m_idx][n_loop].u64[0] = hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[0], symmetry_data.u64[0]);
s_reg[m_idx][n_loop].u64[1] = hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[1], symmetry_data.u64[1]);
s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[0], symmetry_data.u64[0]);
s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[1], symmetry_data.u64[1]);
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
......@@ -471,9 +471,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scale_softmax_log2_pair[1] = scale_softmax_log2;
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
s_reg[m_idx][n_loop].u64[0] = hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].u64[1] = hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled);
asm volatile("s_nop 0" ::: "memory");
s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].f32[0] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[0]);
s_reg[m_idx][n_loop].f32[1] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[1]);
s_reg[m_idx][n_loop].f32[2] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[2]);
......@@ -489,8 +488,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum_pair[1] = 0;
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
scores_sum_pair = hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]);
scores_sum_pair = hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]);
scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]);
scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]);
}
scores_sum_cur[m_idx] = scores_sum_pair[0] + scores_sum_pair[1];
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 32);
......@@ -505,8 +504,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum[m_idx] *= scores_scale[0];
#pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale);
acc_o[m_idx][pv_tile].u64[1] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale);
acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale);
acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale);
}
}
// update max/sum
......@@ -607,7 +606,7 @@ __forceinline__ __device__ void mla_prefix_prefill_cvt_dtype(
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
p_reg[m_idx][n_loop].f16x2[vec_idx] = DownCastPair<ElementAccum, Element>(s_reg[m_idx][n_loop].f32x2[vec_idx]);
......@@ -932,8 +931,8 @@ __forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o(
inv_sum[1] = inv_sum[0];
#pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum);
acc_o[m_idx][pv_tile].u64[1] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum);
acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum);
acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum);
}
}
}
......@@ -962,7 +961,7 @@ __forceinline__ __device__ void mla_prefix_prefill_store_output(
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
vec2_Element<Element> data;
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
data[mmac_id] = DownCast<ElementAccum, Element, true>(acc_o[m_idx][v_tile * 2 + mmac_id].f32[vec_idx]);
......
......@@ -74,7 +74,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4;
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
......@@ -97,7 +97,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
......@@ -220,7 +220,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4;
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
......@@ -243,7 +243,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
......
......@@ -27,7 +27,7 @@ __forceinline__ __device__ void mla_qk_gemm_prefetch_v_tile16x32(
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
#if defined(__gfx936__) || defined(__gfx938__) // >= bmz
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
......
......@@ -13,7 +13,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_tile16x32(
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset=-1) {
#if defined(__gfx928__)
#if defined(__gfx928__) || defined(__gfx92a__)
constexpr int Q_LOAD_REQUESTS = (kBlockM * kBlockK >> 1/*16x32 tile*/) * M_MMAC_COUNT / (4 * 32 * WARP_NUM);
constexpr int SEQUENCE_READ = M_MMAC_COUNT;
constexpr int READ_ONCE_LINES = 4;
......@@ -99,7 +99,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_tile16x32(
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
#elif defined(__gfx936__) || defined(__gfx938__)
#elif defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int lane_id = threadIdx.x & 63;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
......@@ -143,7 +143,7 @@ __forceinline__ __device__ void mla_prefetch_k_to_lds_tile16x32(
// 预先计算一些表达式
int lane_id = threadIdx.x & 63; // lane id, 0-63
#if defined(__gfx936__) || defined(__gfx938__) // >= bmz
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
......
......@@ -117,7 +117,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
if(zero_init == true) {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
summary[m_idx * 2].u64 = 0x0;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
......@@ -125,7 +125,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -151,7 +151,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
} else {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
......@@ -159,7 +159,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -258,15 +258,14 @@ inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WA
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
......@@ -347,10 +346,10 @@ inline __device__ void mla_softmax_rescale_o(
int loop_id = (pv_n_loop * K_WARP_COUNT + ni) * M_WARP_COUNT + mi;
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
......@@ -401,8 +400,8 @@ inline __device__ void mla_softmax_rescale_o(
#pragma unroll
for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
......@@ -425,8 +424,8 @@ inline __device__ void mla_softmax_rescale_o(
}
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......@@ -454,7 +453,7 @@ inline __device__ void mla_convert_pk_type(union_vec2_f16x2<Element> p_reg[M_WAR
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) || defined(__gfx92a__)
p_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
......
......@@ -79,8 +79,6 @@ union union_vec_fp32 {
union union_vec4_uint {
unsigned long long u64[2]; // 128 bits
uint4 u32;
vec4_int i32;
vec4_uint v32;
uint8_t u8[16];
};
......@@ -261,3 +259,4 @@ __forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_
// return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)};
#endif
}
......@@ -13,8 +13,10 @@
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccumType, int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, int kHeadDim_, int STAGES_, bool USE_BSHD_LAYOUT, typename Params>
__global__ void __launch_bounds__(256,1) flash_bwd_dot_do_o_kernel(Params params) {
#if defined(__gfx938__)
compute_dot_do_o_gfx938<true, Is_even_MN, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, kHeadDim_, STAGES_, USE_BSHD_LAYOUT>(params);
#if defined(__gfx946__)
// compute_dot_do_o_gfx946<true, Is_even_MN, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, kHeadDim_, STAGES_, USE_BSHD_LAYOUT>(params);
#elif defined(__gfx938__)
compute_dot_do_o<true, Is_even_MN, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, kHeadDim_, STAGES_, USE_BSHD_LAYOUT>(params);
#else
compute_dot_do_o<true, Is_even_MN, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, kHeadDim_, STAGES_, USE_BSHD_LAYOUT>(params);
#endif
......@@ -27,8 +29,10 @@ __global__ void __launch_bounds__(256,1) flash_attention_dv_dk_bwd_kernel(Param
const int bidb = bidbh / params.h;
const int bidh = bidbh % params.h;
const int n_block = blockIdx.y;
#if defined(__gfx938__)
compute_dk_dv_1colblock_gfx938<Element, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, USE_BSHD_LAYOUT>(params, bidb, bidh, n_block);
#if defined(__gfx946__)
// compute_dk_dv_1colblock_gfx946<Element, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, USE_BSHD_LAYOUT>(params, bidb, bidh, n_block);
#elif defined(__gfx938__)
compute_dk_dv_1colblock<Element, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, USE_BSHD_LAYOUT>(params, bidb, bidh, n_block);
#else
compute_dk_dv_1colblock<Element, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, USE_BSHD_LAYOUT>(params, bidb, bidh, n_block);
#endif
......@@ -42,8 +46,10 @@ __global__ void __launch_bounds__(256,1) flash_attention_dq_bwd_kernel(Params p
const int m_actual_block = (params.seqlen_q + kBlockM_ - 1) / kBlockM_;
const int m_block = m_actual_block - 1 - blockIdx.y;
#if defined(__gfx938__)
compute_dq_1colblock_gfx938<ElementType, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, STAGES, USE_BSHD_LAYOUT>(params, bidb, bidh, m_block);
#if defined(__gfx946__)
// compute_dq_1colblock_gfx946<ElementType, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, STAGES, USE_BSHD_LAYOUT>(params, bidb, bidh, m_block);
#elif defined(__gfx938__)
compute_dq_1colblock<ElementType, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, STAGES, USE_BSHD_LAYOUT>(params, bidb, bidh, m_block);
#else
compute_dq_1colblock<ElementType, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, STAGES, USE_BSHD_LAYOUT>(params, bidb, bidh, m_block);
#endif
......@@ -80,7 +86,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params) {
// std::cout<<"USE_BSHD_LAYOUT="<<USE_BSHD_LAYOUT<<std::endl;
hipStream_t stream = NULL;
const bool is_even_MN = ((params.seqlen_k % kBlockN_) == 0) && ((params.seqlen_q) % kBlockM_ == 0) && params.cu_seqlens_q == nullptr;
// Even-MN must be computed with the same tile shape as each launched kernel.
const bool is_even_MN_dot = ((params.seqlen_k % kBlockN_) == 0) && ((params.seqlen_q % kBlockM_) == 0) && params.cu_seqlens_q == nullptr;
//is_even_K指headdim是否是32的整数倍,否则需要进行边界判断
const bool is_even_K = params.d == K;
......@@ -109,7 +116,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params) {
// flash_attention_bwd: 34.9 ms
// flash_bwd_convert_dq_kernel: 0.9 ms
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_MN_dot, IsEvenMNConst, [&] {
flash_bwd_dot_do_o_kernel<true,IsEvenMNConst, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, K_v, STAGES, USE_BSHD_LAYOUT>
<<<grid_m, kMThreads, 0, stream>>>(params);
});
......@@ -148,6 +155,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params) {
constexpr int dk_dv_kBlockK = 32;
constexpr int dk_dv_WARP_M = 32;
constexpr int dk_dv_WARP_N = 32;
const bool is_even_MN_dk_dv = ((params.seqlen_k % dk_dv_kBlockN) == 0) && ((params.seqlen_q % dk_dv_kBlockM) == 0) && params.cu_seqlens_q == nullptr;
dim3 dimBlock;
int maxBlockThreads = 512;
dimBlock.x = min((dk_dv_kBlockN)/(dk_dv_WARP_N)*64, maxBlockThreads);
......@@ -164,7 +172,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params) {
// dim3 grid_n(gridDimx, params.h, params.b);
dim3 grid_n(params.se_balance_cnt, gridDimx, (params.h * params.b/params.se_balance_cnt));
// printf("flash_attention_dv_dk_bwd_kernel : grid(%d, %d, %d) | block(%d, %d, %d)\n", grid_n.x, grid_n.y, grid_n.z, dimBlock.x, dimBlock.y, dimBlock.z);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_MN_dk_dv, IsEvenMNConst, [&] {
flash_attention_dv_dk_bwd_kernel<Element, float, Is_dropout, Is_causal, Is_local, IsEvenMNConst, true, Is_first, Is_last, Seq_parallel, dk_dv_kBlockM, dk_dv_kBlockN, K, K_v, dk_dv_kBlockK, dk_dv_WARP_M, dk_dv_WARP_N, USE_BSHD_LAYOUT>
<<<grid_n, dimBlock, sharedMemSize, stream>>>(params);
});
......@@ -176,12 +184,13 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params) {
constexpr int dq_kBlockK = 32;
constexpr int dq_WARP_M = 32;
constexpr int dq_WARP_N = 32;
const bool is_even_MN_dq = ((params.seqlen_k % dq_kBlockN) == 0) && ((params.seqlen_q % dq_kBlockM) == 0) && params.cu_seqlens_q == nullptr;
int dq_kMThreads = (dq_kBlockM + dq_WARP_M-1)/dq_WARP_M * 64;
const int num_m_block_dq = (params.seqlen_q + dq_kBlockM - 1) / dq_kBlockM;
// dim3 grid_m(num_m_block_dq, params.h, params.b);
dim3 grid_m(params.se_balance_cnt, num_m_block_dq, (params.h * params.b/params.se_balance_cnt));
// printf("flash_attention_dq_bwd_kernel : grid(%d, %d, %d) | block(%d, %d, %d)\n", grid_m.x, grid_m.y, grid_m.z, dq_kMThreads, 1, 1);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_MN_dq, IsEvenMNConst, [&] {
flash_attention_dq_bwd_kernel<Element, float, Is_dropout, Is_causal, Is_local, IsEvenMNConst, true, Is_first, Is_last, Seq_parallel, dq_kBlockM, dq_kBlockN, K, K_v, dq_kBlockK, dq_WARP_M, dq_WARP_N, 2, USE_BSHD_LAYOUT>
<<<grid_m, dq_kMThreads, sharedMemSize, stream>>>(params);
});
......
......@@ -247,6 +247,11 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimV, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, Is_dropout && Is_training, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
int lane_id = threadIdx.x & 63;
......@@ -288,7 +293,7 @@ inline __device__ void compute_attn(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &params, const int bidb, const int __bidh, const int m_block, const int WARP_ID) {
inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &params, const int bidb, const int __bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
......@@ -330,17 +335,17 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
// 计算数据跨度
int seqlen_q_stride = (Layout == 1) ? params.q_row_stride: params.q_row_stride;
int seqlen_k_stride = (Layout == 1) ? params.k_row_stride: params.k_row_stride;
int seqlen_v_stride = (Layout == 1) ? params.v_row_stride: params.v_row_stride;
int seqlen_o_stride = (Layout == 1) ? params.o_row_stride: params.o_row_stride;
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_o_stride = params.o_row_stride;
int64_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int64_t row_offset_lse;
// 获取页表信息
const int page_block_size = params.page_block_size;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = n_block_min;
const int block_table_offset = 0;
const int block_table_idx = n_block_min * kBlockN / page_block_size;
const int block_table_offset = n_block_min * kBlockN - block_table_idx * page_block_size;
if constexpr (Layout == 1) { /*bshd layout, lse is num_heads, total_q*/
row_offset_q = (binfo.sum_s_q + m_block * kBlockM) * int64_t(seqlen_q_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
......@@ -361,9 +366,9 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
auto gV = prepare_for_buffer_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
// attention 变体: Alibi
float gAlibi;
float g_alibi;
if constexpr (Has_alibi) {
gAlibi = reinterpret_cast<ElementAccum*>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
g_alibi = reinterpret_cast<ElementAccum*>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
}
// attention 插件: Dropout
......@@ -372,9 +377,9 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
union_vec2_uint warp_idx_for_dropout;
if constexpr (Is_dropout) {
rand_seed = params.rand_seed;
rand_offset = params.rand_offset + ((bidb * params.h + bidh) << 6) + threadIdx.x & 63; /* 参考官方写法 offset(offset + (bid * nheads + hid) * 32 + tid % 32) */
p_dropout_in_8bits_value = params.p_dropout_in_uint8_t & 0xffffffff; /*hcu 不支持 16bit 和 8bit 的比较指令*/
warp_idx_for_dropout.u32.x = 1 * m_block * (kBlockM / 32)/*前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/ + WARP_ID/*当前 block 内的 warp id*/;
rand_offset = params.rand_offset + ((bidb * params.h + bidh) << 6) + threadIdx.x & 63;
p_dropout_in_8bits_value = params.p_dropout_in_uint8_t & 0xffffffff;
warp_idx_for_dropout.u32.x = 1 * m_block * (kBlockM / 32) + warp_id;
// Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might exit early and no one saves the rng states.
if (m_block == 0 and bidb == 0 and bidh == 0 and threadIdx.x == 0) {
params.rng_state[0] = rand_seed;
......@@ -383,21 +388,17 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
// 预取 Q 的数据到寄存器
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4]; // ds_read mini size is 32 * 32,2 is seq, 4 is head dim
Is_even_MN
? prefetch_q_to_vgpr<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(gQ, q_lds, q_reg, WARP_ID, seqlen_q_stride)
: prefetch_q_to_vgpr<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(gQ, q_lds, q_reg, WARP_ID, seqlen_q_stride, (binfo.actual_seqlen_q - m_block * kBlockM));
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4];
prefetch_q_to_vgpr<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(gQ, q_lds, q_reg, warp_id, seqlen_q_stride, binfo.actual_seqlen_q - m_block * kBlockM);
/***************************************************************************************************************************/
/***************************************************************************************************************************/
vec2_Accum<ElementAccum> scores_max[WARP_M / 32];
vec2_Accum<ElementAccum> scores_sum[WARP_M / 32];
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4];
attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, kBlockK / 32, 2/*M_MMAC_COUNT*/, ElementAccum>(scores_max, scores_sum, acc_o);
/***************************************************************************************************************************/
/***************************************************************************************************************************/
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr bool PREFETCH_K = false;
constexpr bool PREFETCH_K = false; // true;
constexpr bool Aggressive = (kHeadDim == 128 or kHeadDim == 64);
auto QK_GEMM_FUNC = Aggressive
? &qk_gemm_prefetch_v_headdim128<kHeadDim, kHeadDimV, kBlockM, kBlockN, kBlockK, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN>
......@@ -406,47 +407,36 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
auto PV_GEMM_FUNC = Aggressive
? &pv_gemm_prefetch_k_headdim128<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>
: &pv_gemm_prefetch_k<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
// mask 循环中不需要做 prefetch K, 因此 prefetch K 固定为 false
auto PV_GEMM_FUNC_IN_MASK = Aggressive
? &pv_gemm_prefetch_k_headdim128<false, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>
: &pv_gemm_prefetch_k<false, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
if constexpr (PREFETCH_K) {
prefetch_k_to_lds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(gK, k_lds, warp_id, seqlen_k_stride, binfo.actual_seqlen_k - n_block_min * kBlockN);
}
// constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 : flash::ceil_div(kBlockM, kBlockN);
// These are the iterations where we don't need masking on S
for (int n_block_loop = n_block_min; n_block_loop < n_block_max/*n_block_max - n_masking_steps*/; ++n_block_loop) {
const int seqlen_kv_limit = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// c mini tile is 32 * 32
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / WARP_N)][4];
if constexpr (STAGES > 1) {
Is_even_MN
? prefetch_k_to_lds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(gK, k_lds, WARP_ID, seqlen_k_stride)
: prefetch_k_to_lds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(gK, k_lds, WARP_ID, seqlen_k_stride, seqlen_kv_limit);
if constexpr (not PREFETCH_K) {
prefetch_k_to_lds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(gK, k_lds, warp_id, seqlen_k_stride, seqlen_kv_limit);
}
Is_even_MN
? QK_GEMM_FUNC(gQ, gK, gV, q_lds, k_lds, v_lds, q_reg, s_reg, WARP_ID, seqlen_k_stride, seqlen_v_stride, 0)
: QK_GEMM_FUNC(gQ, gK, gV, q_lds, k_lds, v_lds, q_reg, s_reg, WARP_ID, seqlen_k_stride, seqlen_v_stride, seqlen_kv_limit);
QK_GEMM_FUNC(gQ, gK, gV, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_v_stride, seqlen_kv_limit);
if constexpr (Has_alibi) {
apply_alibi<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k, m_block * kBlockM + WARP_ID * WARP_M, binfo.actual_seqlen_q, gAlibi);
apply_alibi<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, binfo.actual_seqlen_q, g_alibi);
}
if constexpr (!Is_causal && !Is_local) {
if constexpr (!Is_even_MN) { apply_mask<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, seqlen_kv_limit); }
} else {
if constexpr (Is_local) {
apply_mask_local<Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + WARP_ID * WARP_M,
binfo.actual_seqlen_q, params.window_size_left,
params.window_size_right);
apply_mask_local<Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
} else if constexpr (Is_causal) {
apply_mask_causal<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + WARP_ID * WARP_M,
binfo.actual_seqlen_q);
}
apply_mask_causal<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, n_block_loop * kBlockN, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, binfo.actual_seqlen_q);
}
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
......@@ -457,12 +447,11 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
// convertType: float2half
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, true/*IsInference*/>(p_reg, s_reg);
Is_even_MN
? PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, WARP_ID, seqlen_k_stride, seqlen_v_stride, 0)
: PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, WARP_ID, seqlen_k_stride, seqlen_v_stride, seqlen_kv_limit);
if constexpr (not PREFETCH_K) {
PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, seqlen_kv_limit);
}
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
......@@ -472,21 +461,31 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
const int offset_diff = block_table_offset_next - block_table_offset_cur;
*(int64_t*)&gK += (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
if constexpr (PREFETCH_K) {
PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, seqlen_kv_limit);
}
*(int64_t*)&gV += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * int64_t(params.v_row_stride)) * sizeof(Element);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimV, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, false/*Is_dropout*/, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) {
fwd_epilogue_store_lse<WARP_M, Is_even_MN, SplitD, false/*Is_Interleaved*/, ElementAccum>(lse, params.softmax_lse_ptr, row_offset_lse, WARP_ID, lane_id, 0, binfo.actual_seqlen_q - m_block * kBlockM);
fwd_epilogue_store_lse<WARP_M, Is_even_MN, SplitD, false/*Is_Interleaved*/, ElementAccum>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, binfo.actual_seqlen_q - m_block * kBlockM);
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, false/*Is_Interleaves*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, WARP_ID, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, false/*Is_Interleaves*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
......@@ -674,6 +673,11 @@ inline __device__ void compute_attn_mha_padding_mask_1rowblock(const Params &par
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimV, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, false/*Is_dropout && Is_training*/, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
int lane_id = threadIdx.x & 63;
......@@ -840,6 +844,11 @@ inline __device__ void compute_attn_mha_attn_mask_1rowblock(const Params &params
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimV, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, false/*Is_dropout && Is_training*/, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
int lane_id = threadIdx.x & 63;
......@@ -1107,6 +1116,11 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params &params, c
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimPVCompute, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimPVCompute, Is_dropout && Is_training, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
......@@ -1123,7 +1137,7 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params &params, c
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_attn_gfx938(const Params &params) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
constexpr bool Do_lpt = Is_causal;
const int bidh = Do_lpt ? blockIdx.x: blockIdx.y;
......@@ -1194,8 +1208,8 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
// 获取页表信息
const int page_block_size = params.page_block_size;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = n_block_min;
const int block_table_offset = 0;
const int block_table_idx = n_block_min * kBlockN / page_block_size;
const int block_table_offset = n_block_min * kBlockN - block_table_idx * page_block_size;
if constexpr (Layout == 1) { /*bshd layout, lse is num_heads, total_q*/
row_offset_q = (binfo.sum_s_q + m_block * kBlockM) * int64_t(seqlen_q_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
......@@ -1351,10 +1365,12 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
}
// Attention: mask, causal mask, local mask
if constexpr (Is_local) {
apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q);
} else if constexpr (Is_causal) {
if constexpr (!Is_causal && !Is_local) {
apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, binfo.actual_seqlen_k, warp_offset_in_seqkv);
} else if constexpr (Is_local) {
apply_mask_local_gfx938</*HasWSLeft=*/Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
} else if constexpr (Is_causal) {
apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
......@@ -1388,6 +1404,11 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
if (params.s_aux_ptr != nullptr) {
const float sink_value = reinterpret_cast<const float *>(params.s_aux_ptr)[bidh];
fwd_apply_attention_sink<WARP_M, kBlockK, kHeadDimV, ElementAccum>(
acc_o, scores_max, scores_sum, params.scale_softmax, sink_value);
}
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, Is_dropout && Is_training, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
......@@ -1401,19 +1422,569 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_prefix_prefill_gfx938_kernel(const Params params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// GFX92A kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__gfx938__)
const int bidh = blockIdx.x;
#include "fwd/gfx92a/qk_gemm_prefetch_v_mls_ds_gfx92a.h"
#include "fwd/gfx92a/pv_gemm_prefetch_k_mls_ds_gfx92a.h"
#include "fwd/gfx92a/softmax_gfx92a.h"
#include "fwd/gfx92a/fwd_epilogue_gfx92a.h"
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_attn_mha_1rowblock_gfx92a(const Params &params, const int bidb, const int bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = std::conditional_t<Is_even_MN, int32_t, int64_t>;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int WARP_K = 32;
const int bidb = blockIdx.y;
// 获取当前 TG 处理的任务大小
using BlockInfoType = flash::BlockInfo<Is_Varlen, false/*Is_Kvcache*/, false/*USE_BSHD_LAYOUT*/>;
const BlockInfoType binfo(params, bidb);
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
// 处理边界
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0 || bidh >= params.h/*border judgement*/) return;
int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
int m_block = gridDim.z - 1 - blockIdx.z;
flash::compute_attn_prefix_prefill_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
// 分配 lds
extern __shared__ Element smem[];
Element* q_lds = (Element*)&(smem);
Element* k_lds = q_lds;
Element* v_lds = k_lds;
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal || Is_local) {
n_block_max = std::min(n_block_max, flash::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
// 计算数据跨度
int seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, seqlen_o_stride;
index_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int row_offset_lse;
int headdim_split_id = 0;
fwd_prologue_compute_offset<Layout, kBlockM, kBlockN, kHeadDim, kHeadDimV, kHeadDimV, 0/*SplitD*/, Is_even_MN, false/*Is_PaddingMask*/, Params, decltype(binfo), decltype(row_offset_q)>(
seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, seqlen_o_stride, row_offset_q, row_offset_k, row_offset_v, row_offset_o, row_offset_lse,
headdim_split_id, bidb, bidh, bidh, m_block, n_block_min, binfo, params
);
#if 0
if (int(threadIdx.x) == 0) {
printf("bidb: %d | bidh: %d | actual_seqlen_q: %d | actual_seqlen_k: %d | n_block_max: %d | row_offset_q: %d | row_offset_k: %d | row_offset_v: %d | row_offset_o: %d | seqlen_q_stride: %d | seqlen_k_stride: %d | seqlen_v_stride: %d\n",
bidb, bidh, binfo.actual_seqlen_q, binfo.actual_seqlen_k, n_block_max, row_offset_q, row_offset_k, row_offset_v, row_offset_o, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride);
}
#endif
// 根据起始数据偏移量准备 Q/K/V 的 buffer resource 寄存器
auto q_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
// attention 插件: Alibi
float g_alibi;
if constexpr (Has_alibi) {
g_alibi = reinterpret_cast<ElementAccum*>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
}
// attention 插件: Dropout
unsigned long long rand_seed, rand_offset;
uint32_t p_dropout_in_8bits_value;
union_vec2_uint warp_idx_for_dropout;
if constexpr (Is_dropout and Is_training) {
rand_seed = params.rand_seed;
rand_offset = params.rand_offset + ((bidb * params.h + bidh) << 6) + (threadIdx.x & 63);
p_dropout_in_8bits_value = params.p_dropout_in_uint8_t & 0xffffffff;
warp_idx_for_dropout.u32.x = 1 * m_block * (kBlockM / 32) /* 前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/ + warp_id/*当前 block 内的 warp id*/;
if (Is_training and m_block == 0 and bidb == 0 and bidh == 0 and threadIdx.x == 0) {
params.rng_state[0] = rand_seed;
params.rng_state[1] = rand_offset;
}
}
// 预取 Q 的数据到寄存器
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2];
prefetch_q_to_vgpr_mls_ds_gfx92a<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, q_reg, warp_id, seqlen_q_stride, Is_even_MN ? 0: binfo.actual_seqlen_q - m_block * kBlockM);
// apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1: flash::ceil_div(kBlockM, kBlockN);
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr bool PREFETCH_K = Is_even_MN and kHeadDim == 128 and kHeadDimV == 128;
constexpr bool ALLOW_PREFETCH = (STAGES > 1); // 客观上决定是否开启 prefetch
if constexpr (PREFETCH_K and ALLOW_PREFETCH) {
if (n_block_min < n_block_max - n_masking_steps) {
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, Is_even_MN ? 0: binfo.actual_seqlen_k - n_block_min * kBlockN);
}
}
/***************************************************************************************************************************/
vec2_Accum<ElementAccum> scores_max[WARP_M / 32];
vec2_Accum<ElementAccum> scores_sum[WARP_M / 32];
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4];
attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, kBlockK / 32, 2/*M_MMAC_COUNT*/, ElementAccum>(scores_max, scores_sum, acc_o);
/***************************************************************************************************************************/
auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_gfx92a<kHeadDim, kHeadDimV, kBlockM, kBlockN, kBlockK, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_gfx92a<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_gfx92a<false, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
// Mainloop, 主循环, 不做 causal mask 的部分
for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
flash::raise_priority();
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: binfo.actual_seqlen_k - warp_offset_in_seqkv;
// 预取 K 的数据到 lds
if constexpr (not PREFETCH_K) {
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, warp_seqkv_limit);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / WARP_N)][4];
// QK gemm
QK_GEMM_FUNC(k_ptr, v_ptr, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// Attention 变体 alibi
if constexpr (Has_alibi) {
apply_alibi_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, g_alibi);
}
// Attention 变体 local mask
if constexpr (Is_local) {
apply_mask_local_gfx938</*HasWSLeft=*/Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Attention 变体 dropout
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
apply_dropout<vec4_Accum<ElementAccum>, WARP_M, kBlockN, kNWarps, Is_even_MN>(s_reg, warp_seqkv_limit, 0, rand_seed, rand_offset, p_dropout_in_8bits_value, warp_idx_for_dropout, params.dropout_debug_count);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
// 偏移 K 指针, 提前偏移准备预取 K
*(uint64_t*)&k_ptr += kBlockN * params.k_row_stride * sizeof(Element);
// PV gemm
PV_GEMM_FUNC(v_ptr, k_ptr, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// 偏移 V 指针
*(uint64_t*)&v_ptr += kBlockN * params.v_row_stride * sizeof(Element);
}
// prefetch K 的话, 最后一次多取了一段 K, 为了不影响后续的操作, 需要同步等待
if constexpr (PREFETCH_K) {
buffer_load_lds_dwordx1_wait<0>();
}
/***************************************************************************************************************************/
// Rest loop, 做 causal mask 的部分
int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: binfo.actual_seqlen_k - warp_offset_in_seqkv;
// 预取 K 的数据到 lds
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, warp_seqkv_limit);
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(kBlockN / 32) * (WARP_M / 32)][4];
// QK gemm
QK_GEMM_FUNC(k_ptr, v_ptr, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// 偏移 K 指针, 提前偏移准备预取 K
*(uint64_t*)&k_ptr += kBlockN * params.k_row_stride * sizeof(Element);
// Attention 变体 alibi
if constexpr (Has_alibi) {
apply_alibi_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, g_alibi);
}
// Attention: mask, causal mask, local mask
if constexpr (!Is_causal && !Is_local) {
if constexpr (!Is_even_MN) { apply_mask_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_seqkv_limit); }
} else {
if constexpr (Is_causal) {
apply_mask_causal_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q);
} else if constexpr (Is_local) {
apply_mask_local_gfx938</*HasWSLeft=*/Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
}
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Attention 变体 dropout
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
apply_dropout<vec4_Accum<ElementAccum>, WARP_M, kBlockN, kNWarps, Is_even_MN>(s_reg, warp_seqkv_limit, 0, rand_seed, rand_offset, p_dropout_in_8bits_value, warp_idx_for_dropout, params.dropout_debug_count);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
// PV gemm
PV_GEMM_FUNC_IN_MASK(v_ptr, k_ptr, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// 偏移 V 指针
*(uint64_t*)&v_ptr += kBlockN * params.v_row_stride * sizeof(Element);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, Is_dropout && Is_training, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
constexpr bool Is_Output_Interleave = false;
int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) {
fwd_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, Is_even_MN ? 0: binfo.actual_seqlen_q - m_block * kBlockM);
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_mls_gfx92a<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Output_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_attn_gfx92a(const Params &params) {
#if defined(__gfx92a__)
constexpr bool Do_lpt = Is_causal;
const int bidh = Do_lpt ? blockIdx.x: blockIdx.y;
const int bidb = Do_lpt ? blockIdx.y: blockIdx.z;
int warp_id_vec = threadIdx.x / 64;
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int m_block = Do_lpt ? gridDim.z - 1 - blockIdx.z: blockIdx.x;
flash::compute_attn_mha_1rowblock_gfx92a<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
#endif
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx92a(const Params &params, const int bidb, const int __bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
constexpr int WARP_K = 32;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimVSplit;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int SplitD = Kernel_traits::SplitD;
constexpr int kHeadDimVOrigin = Kernel_traits::kHeadDimV;
// 获取 splitD 结果
const int bidh = __bidh / SplitD;
// 获取当前 TG 处理的任务大小
// const flash::BlockInfo</*Varlen=*/!Is_even_MN, false/*Is_kvcache*/> binfo(params, bidb);
flash::SafeDecodeBlockInfo binfo;
binfo.set_params<Params, /*Is_Q_varlen=*/true, /*Is_K_Cumulative=*/false>(params, bidb);
// 处理边界
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
// 分配 lds
extern __shared__ Element smem[];
Element* q_lds = (Element*)&(smem);
Element* k_lds = q_lds;
Element* v_lds = k_lds;
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal || Is_local) {
n_block_max = std::min(n_block_max, flash::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
// 计算数据跨度
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_o_stride = params.o_row_stride;
int64_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int64_t row_offset_lse;
// 获取页表信息
const int page_block_size = params.page_block_size;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
const int kv_start = n_block_min * kBlockN;
const int block_table_idx = kv_start / page_block_size;
const int block_table_offset = kv_start - block_table_idx * page_block_size;
if constexpr (Layout == 1) { /*bshd layout, lse is num_heads, total_q*/
row_offset_q = (binfo.sum_s_q + m_block * kBlockM) * int64_t(seqlen_q_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(block_table[block_table_idx]) * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_lse = bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM;
row_offset_o = binfo.sum_s_q * int64_t(params.o_head_stride) * params.h + params.o_head_stride * bidh + m_block * kBlockM * seqlen_o_stride;
}
#if 0
if (int(threadIdx.x) == 0) {
printf("bidb: %d | bidh: %d | actual_seqlen_q: %d | actual_seqlen_k: %d | n_block_max: %d | row_offset_q: %ld | row_offset_k: %ld | row_offset_v: %ld | row_offset_o: %ld | seqlen_q_stride: %d | seqlen_k_stride: %d | seqlen_v_stride: %d\n",
bidb, bidh, binfo.actual_seqlen_q, binfo.actual_seqlen_k, n_block_max, row_offset_q, row_offset_k, row_offset_v, row_offset_o, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride);
}
#endif
// 根据起始数据偏移量准备 Q/K/V 的资源寄存器
auto q_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
// attention 变体: Alibi
float g_alibi;
if constexpr (Has_alibi) {
g_alibi = reinterpret_cast<ElementAccum*>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
}
// attention 插件: Dropout
unsigned long long rand_seed, rand_offset;
uint32_t p_dropout_in_8bits_value;
union_vec2_uint warp_idx_for_dropout;
if constexpr (Is_dropout and Is_training) {
rand_seed = params.rand_seed;
rand_offset = params.rand_offset + ((bidb * params.h + bidh) << 6) + (threadIdx.x & 63);
p_dropout_in_8bits_value = params.p_dropout_in_uint8_t & 0xffffffff;
warp_idx_for_dropout.u32.x = 1 * m_block * (kBlockM / 32) /* 前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/ + warp_id/*当前 block 内的 warp id*/;
if (Is_training and m_block == 0 and bidb == 0 and bidh == 0 and threadIdx.x == 0) {
params.rng_state[0] = rand_seed;
params.rng_state[1] = rand_offset;
}
}
// 预取 Q 的数据到寄存器
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2];
prefetch_q_to_vgpr_mls_ds_gfx92a<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, q_reg, warp_id, seqlen_q_stride, Is_even_MN ? 0: binfo.actual_seqlen_q - m_block * kBlockM);
// apply causal mask 的步骤和 no causal mask 的步骤分开算
// prefix prefill 目前没分开算, 明确边界的情况下也可以分开算, 性能会有提升
int n_masking_steps = (!Is_causal && !Is_local) ? 1: min(n_block_max, flash::ceil_div(kBlockM, kBlockN) + 1);
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr bool PREFETCH_K = Is_even_MN and kHeadDim == 128 and kHeadDimV == 128;
constexpr bool ALLOW_PREFETCH = (STAGES > 1); // 客观上决定是否开启 prefetch
if constexpr (PREFETCH_K and ALLOW_PREFETCH) {
if (n_block_min < n_block_max - n_masking_steps) {
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, Is_even_MN ? 0: binfo.actual_seqlen_k - n_block_min * kBlockN);
}
}
/***************************************************************************************************************************/
vec2_Accum<ElementAccum> scores_max[WARP_M / 32];
vec2_Accum<ElementAccum> scores_sum[WARP_M / 32];
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4];
attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, kBlockK / 32, 2/*M_MMAC_COUNT*/, ElementAccum>(scores_max, scores_sum, acc_o);
/***************************************************************************************************************************/
auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_gfx92a<kHeadDim, kHeadDimV, kBlockM, kBlockN, kBlockK, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_gfx92a<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_gfx92a<false, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
// Mainloop, 主循环, 不做 causal mask 的部分
for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: binfo.actual_seqlen_k - warp_offset_in_seqkv;
// 预取 K 的数据到 lds
if constexpr (not PREFETCH_K) {
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, warp_seqkv_limit);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / WARP_N)][4];
// QK gemm
QK_GEMM_FUNC(k_ptr, v_ptr, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// Attention 变体 alibi
if constexpr (Has_alibi) {
apply_alibi_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, g_alibi);
}
// Attention: mask, causal mask, local mask
if constexpr (Is_local) {
apply_mask_local_gfx92a</*HasWSLeft=*/Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Attention 变体 dropout
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
apply_dropout<vec4_Accum<ElementAccum>, WARP_M, kBlockN, kNWarps, Is_even_MN>(s_reg, warp_seqkv_limit, 0, rand_seed, rand_offset, p_dropout_in_8bits_value, warp_idx_for_dropout, params.dropout_debug_count);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
if constexpr (PREFETCH_K) {
*(int64_t*)&k_ptr += (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
}
// PV gemm
PV_GEMM_FUNC(v_ptr, k_ptr, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
if constexpr (not PREFETCH_K) {
*(int64_t*)&k_ptr += (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
}
*(int64_t*)&v_ptr += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * int64_t(params.v_row_stride)) * sizeof(Element);
}
// prefetch K 的话, 最后一次多取了一段 K, 为了不影响后续的操作, 需要同步等待
if constexpr (PREFETCH_K) { buffer_load_lds_dwordx1_wait<0>(); }
/***************************************************************************************************************************/
// Rest loop, 做 causal mask 的部分
int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
// #pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: binfo.actual_seqlen_k - warp_offset_in_seqkv;
// 预取 K 的数据到 lds
if constexpr (true) {
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, warp_seqkv_limit);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / WARP_N)][4];
// QK gemm
QK_GEMM_FUNC(k_ptr, v_ptr, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// Attention 变体 alibi
if constexpr (Has_alibi) {
apply_alibi_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, g_alibi);
}
// Attention: mask, causal mask, local mask
if constexpr (!Is_causal && !Is_local) {
if constexpr (!Is_even_MN) { apply_mask_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_seqkv_limit); }
} else if constexpr (Is_causal) {
apply_mask_causal_gfx92a<vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q);
} else if constexpr (Is_local) {
apply_mask_local_gfx92a</*HasWSLeft=*/Is_local, vec4_Accum<ElementAccum>, WARP_M, kBlockN>(s_reg, warp_offset_in_seqkv, binfo.actual_seqlen_k, warp_offset_in_seq_q, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Attention 变体 dropout
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
apply_dropout<vec4_Accum<ElementAccum>, WARP_M, kBlockN, kNWarps, Is_even_MN>(s_reg, warp_seqkv_limit, 0, rand_seed, rand_offset, p_dropout_in_8bits_value, warp_idx_for_dropout, params.dropout_debug_count);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
// PV gemm
PV_GEMM_FUNC_IN_MASK(v_ptr, k_ptr, v_lds, k_lds, p_reg, acc_o, warp_id, seqlen_k_stride, seqlen_v_stride, warp_seqkv_limit);
// 偏移 V 指针
*(int64_t*)&k_ptr += (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
*(int64_t*)&v_ptr += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * int64_t(params.v_row_stride)) * sizeof(Element);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum<ElementAccum> lse[WARP_M / 32];
fwd_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, Is_dropout && Is_training, ElementAccum>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, params.rp_dropout);
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) {
fwd_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, binfo.actual_seqlen_q - m_block * kBlockM);
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_mls_gfx92a<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, false/*Is_Interleave*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_prefix_prefill_gfx938_kernel(const Params params) {
#if defined(__gfx938__) || defined(__gfx946__)
const int bidh = blockIdx.x;
const int bidb = blockIdx.y;
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int m_block = gridDim.z - 1 - blockIdx.z;
flash::compute_attn_prefix_prefill_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
#endif
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_prefix_prefill_gfx92a_kernel(const Params params) {
#if defined(__gfx92a__)
const int bidh = blockIdx.x;
const int bidb = blockIdx.y;
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int m_block = gridDim.z - 1 - blockIdx.z;
flash::compute_attn_prefix_prefill_1rowblock_gfx92a<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
#endif
}
......
......@@ -358,7 +358,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_tile16x32(const Params
/**********************************************************************************************************************************/
// 主循环, 沿着 seqlenKV 维度, 每次 4 个 wave 共同计算一个 kBLOCKN
const int n_block_min = 0;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
const int n_block_max = (Split and !MLA_FIX_NUM_SPLITS) ? ceil_div(Partition_Size, kBlockN): ceil_div(binfo.actual_seqlen_k, kBlockN);
#else
const int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN); // temp workaround, unroll partition size for zd may lead to wrong results
......@@ -451,15 +451,15 @@ inline __device__ void compute_attn_splitkv_mla(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#include "kvcache/gfx938/kvcache_qk_gemm_prefetch_v_gfx938.h"
#include "kvcache/gfx938/kvcache_pv_gemm_prefetch_k_gfx938.h"
#include "kvcache/gfx938/kvcache_softmax_gfx938.h"
#include "kvcache/gfx938/kvcache_epilogue_gfx938.h"
#include "kvcache/kvcache_acco_reduce_tile16x32.h"
#include "kvcache/kvcache_epilogue.h"
#include "mla/gfx938/fp8_mla_acco_reduce_gfx938.h"
#include "mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h"
#include "mla/gfx938/mla_tp8_epilogue_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_qk_gemm_utils_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_qk_gemm_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_pv_gemm_gfx938.h"
// For FlashMLA, codes almostly copy codes from paged_attention with a few differences.
// Kernel codes listed below can be customized alone if neccessary.
// sgpr: 75, vgpr: 240 | base sgpr: 80, vgpr 254
......@@ -546,7 +546,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
bool is_thread0 = threadIdx.x == 0;
if (is_thread0) {
inline_utcl2_warmup_dword(k_addr);
// inline_utcl2_warmup_dword(k_addr);
}
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
......@@ -584,11 +584,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
int warp_offset_in_seqkv = n_block_loop * kBlockN + warp_id * WARP_N;
int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_loop * kBlockN;
kvcache_prefetch_k_to_lds_gfx938<kBlockK, WARP_N, Element, STAGES, WARP_NUM>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
f16_mla_tp8_prefetch_k_to_lds_gfx938<kBlockK, WARP_N, Element, STAGES, WARP_NUM>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
vec4_Accum<ElementAccum> s_reg[M_WARP_COUNT * N_WARP_COUNT][4];
kvcache_qk_gemm_prefetch_v_gfx938<kHeadDim, kHeadDimVSplit, kBlockM, WARP_N, kBlockK, WARP_M, WARP_N, WARP_NUM, STAGES, M_MMAC_COUNT, Element, ElementAccum>(
f16_mla_tp8_qk_gemm_gfx938<kHeadDim, kHeadDimVSplit, kBlockM, WARP_N, kBlockK, WARP_M, WARP_N, WARP_NUM, STAGES, M_MMAC_COUNT, Element, ElementAccum>(
q_addr, k_addr, v_addr, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit);
if constexpr (Is_causal) {
......@@ -603,7 +603,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * N_WARP_COUNT][4];
mla_convert_pk_type<M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT, Element, ElementAccum>(p_reg, s_reg);
kvcache_pv_gemm_prefetch_k_gfx938<K_LOOP_COUNT, kBlockM, kBlockK, kBlockN, M_WARP_COUNT, K_WARP_COUNT/*kBlockK*/, N_WARP_COUNT/*WARP_N*/, STAGES, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
f16_mla_tp8_pv_gemm_gfx938<K_LOOP_COUNT, kBlockM, kBlockK, kBlockN, M_WARP_COUNT, K_WARP_COUNT/*kBlockK*/, N_WARP_COUNT/*WARP_N*/, STAGES, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
v_addr, k_addr, v_lds, k_lds, p_reg, acc_o, warp_id, vcache_seqlen_stride, warp_seqkv_limit);
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
......@@ -646,7 +646,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
template<typename Kernel_traits, bool Is_causal, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_mla_gfx938(const Params &params) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
// The block index for the head.
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y; // batch x num_head, num_head first
......@@ -786,7 +786,7 @@ inline __device__ void flash_fwd_mla_prefix_prefill_kernel_base(const Params par
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, bool Is_prefix, bool Is_causal, typename Element, typename ElementAccum, typename Params>
__global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_fix_kernel(const Params params) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int q_blocks = params.q_blocks;
for(int loop = blockIdx.x; loop < params.total_blocks; loop += params.cu_count) {
int m_block = loop % q_blocks;
......@@ -818,7 +818,7 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_fix_kerne
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, bool Is_prefix, bool Is_causal, typename Element, typename ElementAccum, typename Params>
__global__ void __launch_bounds__(512, 1) flash_fwd_mla_fix_kernel(const Params params) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int q_blocks = params.q_blocks;
for(int loop = blockIdx.x; loop < params.total_blocks; loop += params.cu_count) {
int m_block = loop % q_blocks;
......@@ -976,7 +976,7 @@ inline __device__ void flash_fwd_mla_fast_prefix_prefill_kernel_base(const Param
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, bool Is_prefix, bool Is_causal, typename Element, typename ElementAccum, typename Params>
__global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_kernel(const Params params) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
const int q_blocks = params.q_blocks;
for (int m_block = 0; m_block < q_blocks; ++m_block) {
// 获取当前任务
......@@ -996,7 +996,7 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_kernel(co
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, bool Is_prefix, bool Is_causal, typename Element, typename ElementAccum, typename Params>
__global__ void __launch_bounds__(512, 1) flash_fwd_mla_kernel(const Params params) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// 获取当前任务
const int q_blocks = params.q_blocks;
for (int m_block = 0; m_block < q_blocks; ++m_block) {
......@@ -1132,22 +1132,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......@@ -1391,22 +1385,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......@@ -1662,22 +1650,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......
......@@ -89,7 +89,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv(const Params &params,
const int64_t row_offset_q = Is_Varlen
? binfo.sum_s_q * ngroups * query_seqlen_stride + bidh * params.q_head_stride + m_block * kBlockM * query_seqlen_stride
: bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * query_seqlen_stride;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
constexpr bool USE_CACHE_SWIZZLE = false;
#else
constexpr bool USE_CACHE_SWIZZLE = true; // for gfx928, cache swizzle have significant influence
......@@ -292,7 +292,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
#include "kvcache/kvcache_softmax_tile16x32.h"
#include "kvcache/kvcache_acco_reduce_tile16x32.h"
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &params, const int bidb, const int bidh, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -385,7 +385,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
: bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * query_seqlen_stride;
// 准备读取数据的 buffer resource 寄存器
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
constexpr bool USE_CACHE_SWIZZLE = false;
#else
constexpr bool USE_CACHE_SWIZZLE = true; // for gfx928, cache swizzle have significant influence
......@@ -497,7 +497,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
int lane_id = thread_id & 63;
if constexpr (WARP_NUM > 1) {
int reduced_q_len = Is_Varlen ? params.seqlen_q: actual_seqlen_q;
kvcache_acco_reduce_tile16x32<REUSE_KV_TIMES, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, 4/*Padding*/, ElementAccum>(acc_o, acc_o_lds, reduced_q_len, warp_id, lane_id);
kvcache_acco_reduce_tile16x32<K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, 4/*Padding*/, ElementAccum>(acc_o, acc_o_lds, reduced_q_len, warp_id, lane_id);
}
/**********************************************************************************************************************************/
......@@ -525,7 +525,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_tile16x32(const Params &params) {
// The block index for the head.
......@@ -537,7 +537,7 @@ inline __device__ void compute_attn_splitkv_tile16x32(const Params &params) {
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
flash::compute_attn_1rowblock_splitkv_tile16x32<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
flash::compute_attn_1rowblock_splitkv_tile16x32<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
}
......@@ -550,7 +550,7 @@ inline __device__ void compute_attn_splitkv_tile16x32(const Params &params) {
#include "kvcache/gfx938/kvcache_softmax_gfx938.h"
#include "kvcache/gfx938/kvcache_epilogue_gfx938.h"
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &params, const int bidb, const int bidh, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -574,9 +574,9 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
binfo.set_params<Params, /*Is_Q_varlen=*/Is_Varlen, /*Is_K_Cumulative=*/false>(params, bidb);
// splitKV, 根据 split id 确定当前 split 在 seqlen_kv 上处理的长度
int split_id;
int split_id = 0;
int original_actual_seqlen_k = binfo.actual_seqlen_k;
int partition_size;
int partition_size = 0;
if constexpr (Split) {
split_id = blockIdx.y;
if constexpr (Is_Varlen) {
......@@ -642,10 +642,12 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
bool is_thread0 = threadIdx.x == 0;
if (is_thread0) {
inline_utcl2_warmup_dword(q_addr);
inline_utcl2_warmup_dword(k_addr);
inline_utcl2_warmup_dword(v_addr);
// inline_utcl2_warmup_dword(q_addr);
// inline_utcl2_warmup_dword(k_addr);
// inline_utcl2_warmup_dword(v_addr);
}
// Keep warmup buffer loads out of the MLS vmcnt schedule below.
flash::wait_all_buffer_data_arrived<true>();
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
int row_offset_lse;
......@@ -746,12 +748,17 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
int thread_id = threadIdx.x;
int lane_id = thread_id & 63;
if constexpr (WARP_NUM > 1) {
kvcache_acco_reduce_tile16x32<REUSE_KV_TIMES, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, 0/*Padding*/, ElementAccum>(acc_o, acc_o_lds, params.seqlen_q, warp_id, lane_id);
kvcache_acco_reduce_tile16x32<K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, 0/*Padding*/, ElementAccum>(acc_o, acc_o_lds, params.seqlen_q, warp_id, lane_id);
}
/**********************************************************************************************************************************/
// Epilogue, 收尾工作
// 收尾 1: 根据最后的归一化求和, 做 rescale
if (params.s_aux_ptr != nullptr && split_id == 0) {
fp8_kvcache_apply_attention_sink_gfx938<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
acc_o, scores_max, scores_sum, params.s_aux_ptr, params.s_aux_type,
bidh, params.h, ngroups, m_block, kBlockM, lane_id, params.scale_softmax);
}
kvcache_epilugue_rescale_acco<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(acc_o, scores_sum);
// 收尾 2: splitkv, 或者开启 debug 的情况下, 写出 scores_max, scores_sum
......@@ -776,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_gfx938(const Params &params) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
// The block index for the head.
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y; // batch x num_head, num_head first
......@@ -789,11 +796,272 @@ inline __device__ void compute_attn_splitkv_gfx938(const Params &params) {
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
flash::compute_attn_1rowblock_splitkv_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
flash::compute_attn_1rowblock_splitkv_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MLS-based Paged Attention, gfx92a
////////////////////////////////////////////////////////////////////////////////////////////////////
#include "kvcache/gfx92a/f16_kvcache_gfx92a.h"
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Is_monopolize, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_gfx92a(const Params &params, const int bidb, const int bidh, const int warp_id) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockN / WARP_N;
constexpr int kHeadDimVSplit = kHeadDimV / HEADDIM_V_SPLIT;
// flash::BlockInfo</*Varlen=*/true, /*Is_Kvcache*/true> binfo(params, bidb);
flash::SafeDecodeBlockInfo binfo;
binfo.set_params<Params, /*Is_Q_varlen=*/Is_Varlen, /*Is_K_Cumulative=*/false>(params, bidb);
// SplitKV processing
int split_id;
int original_actual_seqlen_k = binfo.actual_seqlen_k;
int partition_size;
if constexpr (Split) {
split_id = blockIdx.y;
if constexpr (Is_Varlen) {
partition_size = splitkv_get_partitionsize_of_fix_numsplits(binfo.actual_seqlen_k, params.num_splits);
binfo.actual_seqlen_k = min(binfo.actual_seqlen_k - split_id * partition_size, partition_size);
} else {
partition_size = params.partition_size;
int num_splits = max(1, floor_div(binfo.actual_seqlen_k, partition_size));
binfo.actual_seqlen_k = (split_id == num_splits - 1)
? binfo.actual_seqlen_k - split_id * partition_size: partition_size;
binfo.actual_seqlen_k = (split_id >= num_splits) ? 0: binfo.actual_seqlen_k;
if (split_id >= num_splits) return;
}
}
// acquire TG id
int block_x = blockIdx.x;
const int m_block = block_x / HEADDIM_V_SPLIT;
const int headdim_split_id = block_x & (HEADDIM_V_SPLIT - 1);
// Compute seqQ
int ngroups, actual_seqlen_q;
if constexpr (Is_Varlen) {
ngroups = params.ngroups;
actual_seqlen_q = binfo.actual_seqlen_q * ngroups;
} else {
actual_seqlen_q = binfo.actual_seqlen_q;
}
// Running boundaries
if (m_block * kBlockM >= actual_seqlen_q || binfo.actual_seqlen_k <= 0) return;
// Decide lsa usage
extern __shared__ Element smem[];
Element* q_lds = reinterpret_cast<Element*>(smem);
Element* k_lds = reinterpret_cast<Element*>(smem);
Element* v_lds = Is_monopolize ? k_lds + 16384: k_lds;
ElementAccum* acc_o_lds = reinterpret_cast<ElementAccum*>(smem);
ElementAccum* max_lds = acc_o_lds + 1024/*from 4096 bytes*/;
// Acquire stride along seq dimension of q/k/v
int query_seqlen_stride = params.q_row_stride;
int kcache_seqlen_stride = params.k_row_stride;
int vcache_seqlen_stride = params.v_row_stride;
// Compute q and k/v block table address
int page_block_size = params.page_block_size;
int this_split_seqlen_start = Split ? split_id * partition_size: 0;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
block_table = block_table + (Split ? ceil_div(this_split_seqlen_start, page_block_size) : 0);
const int block_table_idx = 0;
const int block_table_offset = 0;
const int64_t row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const int64_t row_offset_v = int64_t(block_table[block_table_idx]) * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const int64_t row_offset_q = Is_Varlen
? binfo.sum_s_q * ngroups * query_seqlen_stride + bidh * ngroups * params.q_head_stride + m_block * kBlockM * query_seqlen_stride
: bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * query_seqlen_stride;
// Prepare buffer resource for q/k/v
Element* q_ptr = reinterpret_cast<Element*>(params.q_ptr) + row_offset_q;
auto q_addr = prepare_for_buffer_load<kHeadDim, Element, false>(q_ptr);
auto k_addr = prepare_for_buffer_load<kHeadDim, Element, false>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto v_addr = prepare_for_buffer_load<kHeadDimV, Element, false>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v + headdim_split_id * kHeadDimVSplit);
// utcl2 warmup
if constexpr (false) {
bool is_thread0 = threadIdx.x == 0;
if (is_thread0) {
inline_utcl2_warmup_dword(q_addr);
inline_utcl2_warmup_dword(k_addr);
inline_utcl2_warmup_dword(v_addr);
}
}
// Compute lse/max/sum pointers of this TG
int row_offset_lse;
ElementAccum * scores_sum_ptr;
ElementAccum * scores_max_ptr;
ElementAccum * softmax_lse_ptr;
if constexpr (Split) {
int row_offset_scores_split;
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.h * ngroups * params.total_q);
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse + row_offset_scores_split;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.b * params.h * params.seqlen_q);
scores_sum_ptr = reinterpret_cast<ElementAccum*>(params.scores_sum_ptr) + row_offset_lse + row_offset_scores_split;
scores_max_ptr = reinterpret_cast<ElementAccum*>(params.scores_max_ptr) + row_offset_lse + row_offset_scores_split;
}
} else {
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
}
}
// hold q regs
constexpr int M_WARP_COUNT = WARP_M / 32;
constexpr int K_WARP_COUNT = kBlockK / 32;
constexpr int N_WARP_COUNT = WARP_N / 32;
constexpr int K_LOOP_COUNT = kHeadDimVSplit / kBlockK;
constexpr int Q_LOAD_BLOCKS = STAGES == 2 ? (kHeadDim / kBlockK): 1;
union_vec4_f16x2<Element> q_reg[Q_LOAD_BLOCKS * M_WARP_COUNT * K_WARP_COUNT * 2];
// prefetch Q into vgprs, can be hide
gfx92a::kvcache_prefetch_q_to_vgpr<Is_Varlen, kHeadDim, kBlockK, WARP_M, WARP_NUM, M_MMAC_COUNT, Element>(
q_ptr, q_lds, q_reg, warp_id, query_seqlen_stride, params.q_head_stride, ngroups, actual_seqlen_q - m_block * kBlockM);
// Initialize, scores_max/scores_max/acc_o
vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT];
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT];
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4];
attention_initialize<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
// Mainloop, along seqlenkv dimension, 4 warps computes attention of a kBlockN
const int n_block_min = 0;
const int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
int n_block_loop = n_block_min;
for (; n_block_loop < n_block_max; ++n_block_loop) {
int warp_offset_in_seqkv = n_block_loop * kBlockN + warp_id * WARP_N;
int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_loop * kBlockN;
constexpr int prefetchKLevel = 4;
constexpr int prefetchVLevel = Is_monopolize ? 4: 2;
constexpr bool prefetchK = Is_monopolize;
if constexpr (prefetchK) {
if (n_block_loop == n_block_min) gfx92a::kvcache_prefetch_k_to_lds<kBlockK, WARP_N, prefetchKLevel, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
} else {
gfx92a::kvcache_prefetch_k_to_lds<kBlockK, WARP_N, prefetchKLevel, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
}
vec4_Accum<ElementAccum> s_reg[M_WARP_COUNT * N_WARP_COUNT][4];
gfx92a::kvcache_qk_gemm_prefetch_v<kHeadDim, kHeadDimVSplit, WARP_N, kBlockK, WARP_M, WARP_N, prefetchKLevel, prefetchVLevel, M_MMAC_COUNT, Element, ElementAccum>(
k_addr, v_addr, k_lds, v_lds, q_reg, s_reg, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit);
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int offset_diff = block_table_offset_next - block_table_offset_cur;
int table_diff;
int table_cur, table_next;
if constexpr (prefetchK) {
inline_global_load_dwordx1(table_cur, block_table_idx_cur, block_table);
inline_global_load_dwordx1(table_next, block_table_idx_next, block_table);
}
gfx92a::kvcache_prefetch_v_to_lds<kHeadDimV, kBlockK, kBlockK, STAGES, prefetchVLevel, Element>(v_addr, v_lds, warp_id, vcache_seqlen_stride, warp_seqkv_limit);
if constexpr (Is_causal) {
gfx92a::kvcache_apply_mask_causal<M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT, Is_Varlen>(s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, ngroups, params.mtp, params.layout);
} else {
gfx92a::kvcache_apply_mask<M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(s_reg, warp_seqkv_limit, warp_id * WARP_N);
}
mla_softmax_rescale_o<Is_causal, ElementAccum, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, N_WARP_COUNT, WARP_NUM, M_MMAC_COUNT>(
s_reg, scores_max, scores_sum, acc_o, max_lds, warp_id, params.scale_softmax_log2);
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * N_WARP_COUNT][4];
gfx92a::convert_attn_f32_to_f16<M_MMAC_COUNT, Element, ElementAccum>(s_reg, p_reg);
if constexpr (prefetchK) {
flash::wait_buffer_data_arrived<true/*can be false*/>(prefetchVLevel/*4 for hdim 128*/);
table_diff = __builtin_amdgcn_readfirstlane(table_next - table_cur);
} else {
table_diff = __builtin_amdgcn_readfirstlane(block_table[block_table_idx_next] - block_table[block_table_idx_cur]);
}
*(int64_t*)&k_addr += (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * params.k_row_stride) * sizeof(Element);
gfx92a::kvcache_pv_gemm_prefetch_k<prefetchK, K_LOOP_COUNT, kBlockK, kBlockN, M_WARP_COUNT, K_WARP_COUNT, N_WARP_COUNT, STAGES, prefetchKLevel, prefetchVLevel, M_MMAC_COUNT, Element, ElementAccum>(
v_addr, k_addr, v_lds, k_lds, p_reg, acc_o, warp_id, vcache_seqlen_stride, kcache_seqlen_stride, warp_seqkv_limit);
*(int64_t*)&v_addr += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * params.v_row_stride) * sizeof(Element);
}
// reduce pv results among 4 warps
int thread_id = threadIdx.x;
int lane_id = thread_id & 63;
if constexpr (WARP_NUM > 1) {
gfx92a::kvcache_acco_reduce_tile16x32<K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, ElementAccum>(acc_o, acc_o_lds, params.seqlen_q, warp_id, lane_id);
}
/**********************************************************************************************************************************/
// Epilogue 1: rescaling acc_o
kvcache_epilugue_rescale_acco<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(acc_o, scores_sum);
// Epilogue 2: store lse / max / sum for splitkv reduction
if constexpr (Is_Varlen) {
kvcache_epilogue_store_softmax_lse<Is_Varlen, true, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, softmax_lse_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM, params.total_q, params.ngroups);
} else {
kvcache_epilogue_store_max_sum<Split, true/*Is_16x32*/, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, scores_max_ptr, scores_sum_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM);
}
// Epilogue 3: store acc_o into global memory
int64_t row_offset_o = Is_Varlen ? binfo.sum_s_q * ngroups * int64_t(params.o_row_stride) + bidh * ngroups * params.o_head_stride + headdim_split_id * kHeadDimVSplit + (Split ? split_id * params.ngroups * int64_t(params.total_q) * params.o_row_stride: 0)
: bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + headdim_split_id * kHeadDimVSplit + (Split ? split_id * params.b * params.o_batch_stride: 0);
gfx92a::kvcache_varlen_epilogue_store_output<Is_Varlen, Split, kBlockK, WARP_NUM, K_LOOP_COUNT, M_MMAC_COUNT, SplitkvAccumType, ElementAccum>(
acc_o, params, row_offset_o, actual_seqlen_q - m_block * kBlockM, warp_id, lane_id);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Is_monopolize, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, typename Params>
inline __device__ void compute_attn_splitkv_gfx92a(const Params &params) {
#if defined(__gfx92a__)
// The block index for the head.
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y; // batch x num_head, num_head first
// The block index for the batch.
const int bidb = Split ? blockIdx.z / params.h : blockIdx.z;
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
flash::compute_attn_1rowblock_splitkv_gfx92a<Kernel_traits, Is_causal, Is_Varlen, Is_monopolize, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT, Params>(params, bidb, bidh, warp_id);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FMA-based Paged Attention
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -864,7 +1132,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_fma_kernel(Param
ElementAccum scores_sum = 0;
// 准备必要的 lds
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__shared__ ElementAccum lds[4096]; // 16384 bytes, allow 4 waves per simd
#else
__shared__ ElementAccum lds[16384]; // 65536 bytes, allow 1 waves per simd for zd
......@@ -1095,7 +1363,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_kernel(Params pa
ElementAccum scores_sum = 0;
// 准备必要的 lds
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__shared__ ElementAccum lds[4096]; // 16384 bytes, allow 4 waves per simd
#else
__shared__ ElementAccum lds[16384]; // 65536 bytes, allow 1 waves per simd for zd
......
......@@ -284,4 +284,576 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_int8_prefix_prefill_kernel(c
compute_attn_int8_prefix_prefill_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// GFX938 kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool FP8_DEBUG, bool Is_even_MN, int kBlockM, int kBlockN, int WARP_M, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_debug_p_reg(
Element* p_reg_ptr,
union_vec32_fp8 p_reg[WARP_M / 16],
int bidb,
int bidh,
int h,
int actual_seqlen_q,
int actual_seqlen_k,
int max_seq_q_offset,
int max_seq_kv_offset,
int m_block,
int n_block_loop,
int warp_id,
int lane_id
) {
if constexpr (FP8_DEBUG) {
__builtin_amdgcn_sched_barrier(0);
if constexpr (FP8_DEBUG) {
Element* p_reg_buffer = p_reg_ptr + (bidb * h + bidh) * actual_seqlen_q * actual_seqlen_k;
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
int row_pos = m_block * kBlockM + warp_id * WARP_M + ((lane_id & 15) >> 2) * 8 + m_idx * 4 + (lane_id & 3);
int col_pos = (lane_id >> 4) * 8 + n_idx * 4 + k_loop * WARP_N + n_block_loop * kBlockN;
*(int32_t*)(p_reg_buffer + row_pos * actual_seqlen_k + col_pos) = p_reg[m_idx].i32[k_loop * 2 + n_idx];
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__builtin_amdgcn_sched_barrier(0);
}
}
#include "fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h"
#include "fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h"
#include "fwd/gfx938/fp8_softmax_gfx938.h"
#include "fwd/gfx938/fp8_epilogue.h"
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_fp8_attn_mha_1rowblock_gfx938(const Params &params, const int bidb, const int bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using Element_k = typename Kernel_traits::Element_k;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int WARP_K = 32;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
// 获取当前 TG 处理的任务大小
const flash::BlockInfo</*Varlen=*/Is_Varlen> binfo(params, bidb);
// 判断任务边界
int max_seq_q_offset = binfo.actual_seqlen_q - m_block * kBlockM;
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k <= 0/* || bidh >= h*/) return;
// 获取 wave id
// int __warp_id = threadIdx.x >> 6;
// int warp_id = __builtin_amdgcn_readfirstlane(__warp_id);
// 定义 lds, 128x128 个 fp8, 16384 bytes
// __shared__ int8_t lds[16384 + 4096 + 16384 + 4096];
extern __shared__ int8_t lds[];
int8_t* q_lds = lds + 0;
int8_t* k_lds = lds + 0;
int8_t* v_lds = lds + 0;
// ========================================== 计算 offset ===========================================
int64_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int64_t row_offset_lse_base;
if constexpr (Is_Varlen) {
if constexpr (Layout == 1) { /* bshd: q/o are [total_q, h, d] */
row_offset_q = (int64_t(binfo.sum_s_q) + m_block * kBlockM) * int64_t(params.q_row_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(binfo.sum_s_k) * int64_t(params.k_row_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(binfo.sum_s_k) * int64_t(params.v_row_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = int64_t(binfo.sum_s_q) * int64_t(params.o_head_stride) * params.h + params.o_head_stride * bidh + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = bidh * int64_t(params.total_q) + binfo.sum_s_q;
} else { /* bhsd */
row_offset_q = int64_t(binfo.sum_s_q) * int64_t(params.q_row_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(params.q_row_stride);
row_offset_k = int64_t(binfo.sum_s_k) * int64_t(params.k_row_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(binfo.sum_s_k) * int64_t(params.v_row_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = int64_t(binfo.sum_s_q) * int64_t(params.o_row_stride) + bidh * params.o_head_stride + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = bidh * int64_t(params.total_q) + binfo.sum_s_q;
}
} else {
row_offset_q = bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(params.q_row_stride);
row_offset_k = bidb * int64_t(params.k_batch_stride) + int(bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = bidb * int64_t(params.v_batch_stride) + int(bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + m_block * kBlockM * int64_t(params.o_row_stride);
row_offset_lse_base = (bidb * params.h + bidh) * int64_t(binfo.actual_seqlen_q);
}
Element_k* q_ptr = reinterpret_cast<Element_k*>(params.q_ptr) + row_offset_q;
Element_k* k_ptr = reinterpret_cast<Element_k*>(params.k_ptr) + row_offset_k;
Element_k* v_ptr = reinterpret_cast<Element_k*>(params.v_ptr) + row_offset_v;
ElementAccum* q_descale_ptr = reinterpret_cast<ElementAccum*>(params.q_descale_ptr);
ElementAccum* k_descale_ptr = reinterpret_cast<ElementAccum*>(params.k_descale_ptr);
ElementAccum* v_descale_ptr = reinterpret_cast<ElementAccum*>(params.v_descale_ptr);
ElementAccum q_descale = q_descale_ptr[0];
ElementAccum k_descale = k_descale_ptr[0];
ElementAccum qk_descale = q_descale * k_descale;
ElementAccum softmax_scale = params.scale_softmax * qk_descale;
ElementAccum softmax_scale_log2 = params.scale_softmax_log2 * qk_descale;
ElementAccum v_descale = v_descale_ptr[0];
// acc_o_ptr = reinterpret_cast<ElementAccum*>(acc_o_ptr) + row_offset_o;
ElementAccum* softmax_lse_ptr = reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr);
Element_k* p_reg_ptr = reinterpret_cast<Element_k *>(params.p_ptr);
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
// ======================================================== 读取 Q ======================================================================
fp8_prefetch_q_to_lds<Is_even_MN, kHeadDim, WARP_M, Element_k>(q_ptr, q_lds, warp_id, params.q_row_stride, max_seq_q_offset);
// 计算解决 bank 冲突必须的一些变量
int tx = threadIdx.x;
int lane_id = tx & 63;
// 准备存储最大值, 求和, acc_o 寄存器 等
ElementAccum scores_max[WARP_M / 16];
ElementAccum scores_sum[WARP_M / 16];
vec4_Accum<ElementAccum> acc_o[kHeadDimV / 32][WARP_M / 16][WARP_N / 16];
fp8_attention_initialize<kHeadDimV, WARP_M, WARP_N, ElementAccum>(scores_max, scores_sum, acc_o);
// 从 lds 读取 q 的数据, 不需要同步
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64];
load_q_from_lds_to_vgpr<kHeadDim, WARP_M, Element_k>(q_regs, q_lds, warp_id, lane_id);
// ======================================================== Prefetch K ======================================================================
fp8_prefetch_k_to_lds<Is_even_MN, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, binfo.actual_seqlen_k);
// ======================================================== Mainloop ======================================================================
// 计算当前 block 计算任务的边界,带 causal mask 的场景可以少计算一些
int n_block_min = 0;
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal) {
n_block_max = std::min(n_block_max, ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + 0/*params.window_size_right*/, kBlockN));
}
constexpr int n_masking_steps = (!Is_causal/* && !Is_local*/) ? 1: ceil_div(kBlockM, kBlockN); // 目前的场景可能需要限制 kBlockM == kBlockN, 主要是考虑到 prefetch K 的数据正确性
constexpr bool Assume_valid_rows = !Is_local && (!Is_causal || !Is_Varlen);
for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
// 计算 kv 的边界
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// ======================================================== QK gemm ======================================================================
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds<Is_even_MN, kBlockN, kHeadDimV, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== Softmax ======================================================================
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDimV, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// ========================================== cvt ===============================================
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<1, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
fp8_pv_gemm_and_prefetch_k<true/*PrefetchK*/, Is_even_MN, kHeadDim, kHeadDimV, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset - kBlockN);
// 计算 k, v 的偏移
v_ptr += kBlockN * params.v_row_stride;
}
// ========================================== Rest ===============================================
// 剩下的需要做 causal mask
int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
// 计算 kv 的边界
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// ======================================================== QK gemm ======================================================================
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds<Is_even_MN, kBlockN, kHeadDimV, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
// ======================================================== causal mask ==================================================================
if constexpr (Is_causal) {
fp8_apply_causal_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN, lane_id);
}
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== mask ==================================================================
// 对齐 fp16 fwd:非 causal 的 rest loop 要屏蔽最后一个 partial KV tile 的越界列。
if constexpr (!Is_causal && !Is_local) {
if constexpr (!Is_even_MN) {
fp8_apply_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, max_seq_kv_offset, 0, lane_id);
}
}
// ======================================================== Softmax ======================================================================
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDimV, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// ========================================== cvt ===============================================
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<0, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
constexpr bool PrefetchK = n_masking_steps > 1;
fp8_pv_gemm_and_prefetch_k<PrefetchK, Is_even_MN, kHeadDim, kHeadDimV, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset - kBlockN);
// 计算 k, v 的偏移
if (not PrefetchK) {
k_ptr += kBlockN * params.k_row_stride;
}
v_ptr += kBlockN * params.v_row_stride;
}
// ========================================== rescale by scores_sum ==========================================
// 根据 scores_sum 对 acc_o 做缩放
ElementAccum lse[WARP_M / 16];
if (params.s_aux_ptr != nullptr) {
const float sink_value = fp8_attention_sink_load(params.s_aux_ptr, params.s_aux_type, bidh);
fp8_attention_sink_apply<kHeadDimV, WARP_M, WARP_N, ElementAccum>(
acc_o, scores_max, scores_sum, softmax_scale, sink_value);
}
if constexpr (Return_softmax) {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDimV, WARP_M, WARP_N, true/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
} else {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDimV, WARP_M, WARP_N, false/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
}
// ========================================== lse storation ==========================================
if constexpr (Return_softmax) {
fp8_epilogue_store_lse<Is_even_MN, WARP_M, ElementAccum>(
softmax_lse_ptr, scores_max, scores_sum, lse, row_offset_lse_base, binfo.actual_seqlen_q, m_block * kBlockM + warp_id * WARP_M, lane_id);
}
// ========================================== Storation =============================================
fp8_epilogue_store_output<Is_even_MN, kBlockM, kHeadDimV, WARP_M, WARP_N, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, params.o_row_stride, binfo.actual_seqlen_q);
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, bool Is_GQA, int Layout, typename Params>
inline __device__ void compute_fp8_attn_gfx938(const Params &params) {
#if defined(__gfx938__) || defined(__gfx946__)
constexpr bool Do_lpt = Is_causal and Is_GQA;
const int bidh = Do_lpt ? blockIdx.x : blockIdx.y;
const int bidb = Do_lpt ? blockIdx.y : blockIdx.z;
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int m_block = Do_lpt ? gridDim.z - 1 - blockIdx.z : blockIdx.x;
flash::compute_fp8_attn_mha_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
if constexpr (Is_causal and !Is_GQA /*MHA causal mask*/) {
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
flash::compute_fp8_attn_mha_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, gridDim.x * 2 - 1 - m_block, warp_id);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 Prefix Prefill (paged KV cache + varlen) for GFX938
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout, typename Params>
inline __device__ void compute_fp8_attn_prefix_prefill_1rowblock_gfx938(const Params &params, const int bidb, const int bidh, const int m_block, const int warp_id) {
using Element = typename Kernel_traits::Element;
using Element_k = typename Kernel_traits::Element_k;
using ElementAccum = typename Kernel_traits::ElementAccum;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int WARP_K = 32;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
// Varlen BlockInfo
const flash::BlockInfo<true/*Varlen*/, false/*Is_kvcache*/> binfo(params, bidb);
int max_seq_q_offset = binfo.actual_seqlen_q - m_block * kBlockM;
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k <= 0) return;
// 定义 lds
extern __shared__ int8_t lds[];
int8_t* q_lds = lds + 0;
int8_t* k_lds = lds + 0;
int8_t* v_lds = lds + 0;
// ========================================== 计算 offset (varlen + paged) ===========================================
const int page_block_size = params.page_block_size;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block_min = 0;
if constexpr (Is_local) {
n_block_min = max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
}
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal || Is_local) {
const int window_size_right = Is_local ? params.window_size_right : 0;
n_block_max = min(n_block_max, ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + window_size_right, kBlockN));
}
if (n_block_min >= n_block_max) return;
const int first_block_table_idx = n_block_min * kBlockN / params.page_block_size;
const int first_block_table_offset = n_block_min * kBlockN - first_block_table_idx * params.page_block_size;
const int first_page = block_table[first_block_table_idx];
int64_t row_offset_q, row_offset_k, row_offset_v, row_offset_o;
int row_offset_lse;
if constexpr (Layout == 1) { /*bshd layout*/
row_offset_q = (binfo.sum_s_q + m_block * kBlockM) * int64_t(params.q_row_stride) + params.q_head_stride * bidh;
row_offset_k = int64_t(first_page) * int64_t(params.k_batch_stride) + first_block_table_offset * int64_t(params.k_row_stride) + (bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(first_page) * int64_t(params.v_batch_stride) + first_block_table_offset * int64_t(params.v_row_stride) + (bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = binfo.sum_s_q * int64_t(params.o_head_stride) * params.h + params.o_head_stride * bidh + m_block * kBlockM * params.o_row_stride;
row_offset_lse = bidh * params.total_q + binfo.sum_s_q;
} else { /*bhsd layout*/
row_offset_q = binfo.sum_s_q * int64_t(params.q_row_stride) + bidh * params.q_head_stride + m_block * kBlockM * params.q_row_stride;
row_offset_k = int64_t(first_page) * int64_t(params.k_batch_stride) + first_block_table_offset * int64_t(params.k_row_stride) + (bidh / params.h_h_k_ratio) * params.k_head_stride;
row_offset_v = int64_t(first_page) * int64_t(params.v_batch_stride) + first_block_table_offset * int64_t(params.v_row_stride) + (bidh / params.h_h_k_ratio) * params.v_head_stride;
row_offset_o = binfo.sum_s_q * int64_t(params.o_row_stride) + bidh * params.o_head_stride + m_block * kBlockM * params.o_row_stride;
row_offset_lse = bidh * params.total_q + binfo.sum_s_q;
}
// FP8 descale tensors are broadcast by taking the first scalar value.
// 使用原始指针 (FP8 prefetch函数内部会调用 prepare_for_matrix_load)
Element_k* q_ptr = reinterpret_cast<Element_k*>(params.q_ptr) + row_offset_q;
Element_k* k_ptr = reinterpret_cast<Element_k*>(params.k_ptr) + row_offset_k;
Element_k* v_ptr = reinterpret_cast<Element_k*>(params.v_ptr) + row_offset_v;
ElementAccum* q_descale_ptr = reinterpret_cast<ElementAccum*>(params.q_descale_ptr);
ElementAccum* k_descale_ptr = reinterpret_cast<ElementAccum*>(params.k_descale_ptr);
ElementAccum* v_descale_ptr = reinterpret_cast<ElementAccum*>(params.v_descale_ptr);
ElementAccum q_descale = q_descale_ptr[0];
ElementAccum k_descale = k_descale_ptr[0];
ElementAccum qk_descale = q_descale * k_descale;
ElementAccum softmax_scale = params.scale_softmax * qk_descale;
ElementAccum softmax_scale_log2 = params.scale_softmax_log2 * qk_descale;
ElementAccum v_descale = v_descale_ptr[0];
ElementAccum* softmax_lse_ptr = reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr);
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
// ======================================================== 读取 Q ======================================================================
fp8_prefetch_q_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_M, Element_k>(q_ptr, q_lds, warp_id, params.q_row_stride, max_seq_q_offset);
int lane_id = threadIdx.x & 63;
// 准备寄存器
ElementAccum scores_max[WARP_M / 16];
ElementAccum scores_sum[WARP_M / 16];
vec4_Accum<ElementAccum> acc_o[kHeadDimV / 32][WARP_M / 16][WARP_N / 16];
fp8_attention_initialize<kHeadDimV, WARP_M, WARP_N, ElementAccum>(scores_max, scores_sum, acc_o);
// 从 lds 读取 q 的数据
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64];
load_q_from_lds_to_vgpr<kHeadDim, WARP_M, Element_k>(q_regs, q_lds, warp_id, lane_id);
// ======================================================== Mainloop ======================================================================
int n_masking_steps = 1;
if constexpr (Is_causal) {
const int causal_start_col = m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q;
const int first_mask_block = max(n_block_min, causal_start_col / kBlockN);
n_masking_steps = n_block_max - first_mask_block;
} else if constexpr (Is_local) {
n_masking_steps = min(n_block_max - n_block_min, ceil_div(kBlockM, kBlockN));
}
n_masking_steps = min(max(n_masking_steps, 1), n_block_max - n_block_min);
constexpr bool Assume_valid_rows = !Is_local;
// ======================================================== Prefetch 第一块 K ======================================================================
if (n_block_max > n_masking_steps) {
fp8_prefetch_k_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, binfo.actual_seqlen_k);
}
// ======================================================== 主循环:不需要 causal mask + Prefetch K ============================================================
for (int n_block_loop = n_block_min; n_block_loop < n_block_max - n_masking_steps; ++n_block_loop) {
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// QK gemm(K 数据已在上一轮 prefetch 到 LDS)
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// Prefetch V
fp8_prefetch_v_to_lds<false/*Is_even_MN*/, kBlockN, kHeadDimV, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
if constexpr (Is_local) {
fp8_apply_local_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(
s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN,
params.window_size_left, params.window_size_right, lane_id);
}
// Softmax + 读取 V 到寄存器
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDimV, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// cvt
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
// PV MMAC + Prefetch 下一块 K(paged KV)
const int next_n_block_loop = n_block_loop + 1;
const int block_table_idx_cur = n_block_loop * kBlockN / page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * page_block_size;
const int block_table_idx_next = next_n_block_loop * kBlockN / page_block_size;
const int block_table_offset_next = next_n_block_loop * kBlockN - block_table_idx_next * page_block_size;
const int64_t table_delta = int64_t(block_table[block_table_idx_next] - block_table[block_table_idx_cur]);
const int64_t offset_delta = int64_t(block_table_offset_next - block_table_offset_cur);
Element_k* k_ptr_next = k_ptr
+ table_delta * int64_t(params.k_batch_stride)
+ offset_delta * int64_t(params.k_row_stride);
const int max_seq_kv_offset_next = binfo.actual_seqlen_k - next_n_block_loop * kBlockN;
fp8_pv_gemm_and_prefetch_k_paged<false/*Is_even_MN*/, kHeadDim, kHeadDimV, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(acc_o, p_reg, v_regs, v_lds, k_ptr_next, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset_next);
// 更新 K/V 指针
k_ptr = k_ptr_next;
v_ptr += table_delta * int64_t(params.v_batch_stride)
+ offset_delta * int64_t(params.v_row_stride);
}
// ======================================================== Masking 循环:需要 causal mask,不 Prefetch K ============================================================
int n_block_loop = max(n_block_max - n_masking_steps, n_block_min);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, ++n_block_loop) {
int max_seq_kv_offset = binfo.actual_seqlen_k - n_block_loop * kBlockN;
// 如果主循环没有 prefetch(n_block_max <= n_masking_steps),需要在这里 prefetch K
if (masking_step == 0 && n_block_max <= n_masking_steps) {
fp8_prefetch_k_to_lds<false/*Is_even_MN*/, kHeadDim, WARP_N, Element_k>(k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset);
}
// QK gemm
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16];
fp8_qk_gemm<kBlockN, kHeadDim, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, q_regs, k_lds);
// Prefetch V
fp8_prefetch_v_to_lds<false/*Is_even_MN*/, kBlockN, kHeadDimV, WARP_N, Element_k>(v_ptr, v_lds, warp_id, params.v_row_stride, max_seq_kv_offset);
// Mask
// 对齐 fp16 fwd:非 causal 的 rest loop 要屏蔽最后一个 partial KV tile 的越界列。
if constexpr (!Is_causal && !Is_local) {
fp8_apply_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, max_seq_kv_offset, 0, lane_id);
}
// Causal mask
if constexpr (Is_local) {
fp8_apply_local_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(
s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN,
params.window_size_left, params.window_size_right, lane_id);
} else if constexpr (Is_causal) {
fp8_apply_causal_mask<kBlockN, WARP_M, WARP_N, ElementAccum>(s_reg, binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + warp_id * WARP_M, n_block_loop * kBlockN, lane_id);
}
// Softmax + 读取 V 到寄存器
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32];
fp8_softmax_and_schedule_v<Assume_valid_rows, kHeadDimV, kBlockN, WARP_M, WARP_N, WARP_K, Element_k, ElementAccum>(s_reg, scores_max, scores_sum, acc_o, softmax_scale_log2, v_regs, v_lds);
// cvt
union_vec32_fp8 p_reg[WARP_M / 16];
fp8_cvt_f32_to_fp8<kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(s_reg, p_reg);
const int next_n_block_loop = n_block_loop + 1;
if (next_n_block_loop < n_block_max) {
const int block_table_idx_cur = n_block_loop * kBlockN / page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * page_block_size;
const int block_table_idx_next = next_n_block_loop * kBlockN / page_block_size;
const int block_table_offset_next = next_n_block_loop * kBlockN - block_table_idx_next * page_block_size;
const int64_t table_delta = int64_t(block_table[block_table_idx_next] - block_table[block_table_idx_cur]);
const int64_t offset_delta = int64_t(block_table_offset_next - block_table_offset_cur);
Element_k* k_ptr_next = k_ptr
+ table_delta * int64_t(params.k_batch_stride)
+ offset_delta * int64_t(params.k_row_stride);
const int max_seq_kv_offset_next = binfo.actual_seqlen_k - next_n_block_loop * kBlockN;
fp8_pv_gemm_and_prefetch_k_paged<false/*Is_even_MN*/, kHeadDim, kHeadDimV, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(
acc_o, p_reg, v_regs, v_lds, k_ptr_next, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset_next);
k_ptr = k_ptr_next;
v_ptr += table_delta * int64_t(params.v_batch_stride)
+ offset_delta * int64_t(params.v_row_stride);
} else {
fp8_pv_gemm_and_prefetch_k<false/*PrefetchK*/, false/*Is_even_MN*/, kHeadDim, kHeadDimV, kBlockN, WARP_M, WARP_N, Element_k, ElementAccum>(
acc_o, p_reg, v_regs, v_lds, k_ptr, k_lds, warp_id, params.k_row_stride, max_seq_kv_offset);
}
}
// ========================================== rescale by scores_sum ==========================================
ElementAccum lse[WARP_M / 16];
if (params.s_aux_ptr != nullptr) {
const float sink_value = fp8_attention_sink_load(params.s_aux_ptr, params.s_aux_type, bidh);
fp8_attention_sink_apply<kHeadDimV, WARP_M, WARP_N, ElementAccum>(
acc_o, scores_max, scores_sum, softmax_scale, sink_value);
}
if constexpr (Return_softmax) {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDimV, WARP_M, WARP_N, true/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
} else {
fp8_epilogue_rescale_acc_o<Assume_valid_rows, kHeadDimV, WARP_M, WARP_N, false/*StoreLSE*/, ElementAccum>(acc_o, scores_max, scores_sum, lse, softmax_scale, v_descale);
}
// ========================================== lse storation (varlen) ==========================================
if constexpr (Return_softmax) {
fp8_epilogue_store_lse<false/*Is_even_MN*/, WARP_M, ElementAccum>(
softmax_lse_ptr, scores_max, scores_sum, lse, row_offset_lse, binfo.actual_seqlen_q, m_block * kBlockM + warp_id * WARP_M, lane_id);
}
// ========================================== Storation =============================================
fp8_epilogue_store_output<false/*Is_even_MN*/, kBlockM, kHeadDimV, WARP_M, WARP_N, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, params.o_row_stride, binfo.actual_seqlen_q);
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_K, bool Return_softmax, bool Has_alibi, int Layout>
__global__ void __launch_bounds__(256, 1) flash_fp8_fwd_prefix_prefill_kernel_gfx938(Flash_fwd_params params) {
#if defined(__gfx938__) || defined(__gfx946__)
// LPT 调度:改变 blockIdx 到 m_block/bidh/bidb 的映射
// causal 模式:blockIdx.x = bidh, blockIdx.y = bidb, blockIdx.z 倒序 = m_block
// 非 causal 模式:blockIdx.x = m_block, blockIdx.y = bidh, blockIdx.z = bidb
constexpr bool Do_lpt = Is_causal;
const int bidh = Do_lpt ? blockIdx.x : blockIdx.y;
const int bidb = Do_lpt ? blockIdx.y : blockIdx.z;
int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / 64);
int m_block = Do_lpt ? gridDim.z - 1 - blockIdx.z : blockIdx.x;
flash::compute_fp8_attn_prefix_prefill_1rowblock_gfx938<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_K, Return_softmax, Has_alibi, Layout, Flash_fwd_params>(params, bidb, bidh, m_block, warp_id);
#endif
}
} // namespace flash
......@@ -210,7 +210,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_fp8_mla_gfx938(const Param
template<typename Kernel_traits, bool Is_causal, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_fp8_mla_gfx938(const Params &params) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
// The block index for the head.
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y; // batch x num_head, num_head first
......@@ -252,7 +252,7 @@ __global__ void flash_mla_convert_query_to_fp8_kernel(
const int nope_row_stride,
const int rope_row_stride,
const int qheads) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
if constexpr (persistent) {
for (int bidb = blockIdx.x; bidb < total_blocks; bidb += gridDim.x) {
// --------------------- nope -------------------------
......
......@@ -108,7 +108,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
const int64_t row_offset_q = bidb * params.q_batch_stride + bidh * params.q_head_stride + m_block * kBlockM * query_seqlen_stride;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
constexpr bool USE_CACHE_SWIZZLE = false;
#else
constexpr bool USE_CACHE_SWIZZLE = true; // for gfx928, cache swizzle have significant influence
......@@ -166,7 +166,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
vec2_Accum<ElementAccum> scores_max[WARP_M/32] = {-INFINITY};
vec2_Accum<ElementAccum> scores_sum[WARP_M/32] = {0};
// 由于当前编译器无法自动生成 v_mov_b64 指令, 主动用 builtin 还会被转译成 v_mov_b32, 因此用内联汇编控制
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
vec4_Accum<ElementAccum> acc_o[(kHeadDimV/kBlockK) * ((WARP_M/32)*(kBlockK/32))][4];
if constexpr (kHeadDimV == 128) { // kHeadDim 128 是主要优化目标
if constexpr (M_MMAC_COUNT == 1) {
......@@ -176,23 +176,15 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
}
__builtin_amdgcn_sched_barrier(0);
} else { // 非 kHeaddim 128, 交给编译器后续的优化了
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (kHeadDimV/kBlockK) * ((WARP_M/32)*(kBlockK/32)); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[1])
:);
#endif
acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(pk_zero);
}
}
}
......@@ -418,4 +410,520 @@ inline __device__ void compute_attn_splitkv_int8(const Params &params) {
flash::compute_attn_mha_1rowblock_splitkv_int8<Kernel_traits, Is_training, Is_dropout, Is_causal, Is_local, Is_even_K, Return_softmax, Has_alibi, Split, M_MMAC_COUNT, REUSE_KV_TIMES, Flash_fwd_params>(params, bidb, bidh, warp_id);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MLS-based FP8 Paged Attention, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_acco_reduce_compact_gfx938(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int warp_id,
int lane_id) {
constexpr int kReduceBlockK = 32;
constexpr int kReduceRows = M_WARP_COUNT * M_MMAC_COUNT * 16;
const int q_seq_idx = lane_id & 15;
const int lane_dim_offset = (lane_id >> 4) * 4;
const int even_reuse_kv_times = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2 : ((seqlen_q + 1) / 2) * 2;
const bool is_valid_q_lane = q_seq_idx < even_reuse_kv_times;
#pragma unroll
for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
if (is_valid_q_lane) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int row_idx = warp_m_idx * M_MMAC_COUNT * 16 + min_tile_m * 16 + q_seq_idx;
const int lds_offset = (warp_id * kReduceRows + row_idx) * kReduceBlockK
+ min_tile_n * 16 + lane_dim_offset;
const int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT
+ k_idx * M_WARP_COUNT + warp_m_idx;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32;
}
}
}
}
__syncthreads();
if constexpr (WARP_NUM > 1) {
if (warp_id == 0) {
if (is_valid_q_lane) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int row_idx = warp_m_idx * M_MMAC_COUNT * 16 + min_tile_m * 16 + q_seq_idx;
const int lds_offset = row_idx * kReduceBlockK
+ min_tile_n * 16 + lane_dim_offset + vec_idx;
ElementAccum acc_tmp = acc_o_lds[lds_offset];
#pragma unroll
for (int loop = 1; loop < WARP_NUM; ++loop) {
acc_tmp += acc_o_lds[lds_offset + loop * kReduceRows * kReduceBlockK];
}
acc_o_lds[lds_offset] = acc_tmp;
}
}
}
}
}
}
}
__syncthreads();
if (is_valid_q_lane) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int row_idx = warp_m_idx * M_MMAC_COUNT * 16 + min_tile_m * 16 + q_seq_idx;
const int lds_offset = row_idx * kReduceBlockK
+ min_tile_n * 16 + lane_dim_offset;
const int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT
+ k_idx * M_WARP_COUNT + warp_m_idx;
acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
}
}
}
}
__syncthreads();
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_kvcache_apply_mask_local_causal_gfx938(
DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q,
const int ngroups, const int window_size_left, const int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 8;
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
const int logical_row = row_idx / ngroups;
const int logical_q = max_seqlen_q / ngroups;
const int col_idx_limit_left = max(0, logical_row + max_seqlen_k - logical_q - window_size_left);
const int col_idx_limit_right = min(max_seqlen_k, logical_row + max_seqlen_k - logical_q + window_size_right);
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 4;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] =
(col_idx < col_idx_limit_left || col_idx > col_idx_limit_right)
? -INFINITY
: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template<int kHeadDim, int kBlockM, int WARP_M, int M_MMAC_COUNT, typename Element>
__forceinline__ __device__ void fp8_mha_prefetch_q_to_vgpr_gfx938(
vec4_uint q_addr,
Element* q_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset) {
static_assert(kHeadDim == 128 || kHeadDim == 256);
static_assert(WARP_M == 32);
vec4_uint q_srsrc;
q_srsrc[1] = q_addr[1];
q_srsrc[2] = query_seqlen_stride;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int k_loop = 0; k_loop < kHeadDim / 128; ++k_loop) {
if (warp_id == min_tile_m) {
const int q_row_base = min_tile_m * 16;
const int valid_rows = max_seq_q_offset - q_row_base;
const int safe_q_row_base = valid_rows <= 0 ? 0 : q_row_base;
const int nm_filter = inline_min_max<0, 16>(16 - valid_rows);
q_srsrc[3] = valid_rows >= 16 ? 0 : (nm_filter << 8);
const int64_t row_offset_bytes = int64_t(safe_q_row_base) * int64_t(query_seqlen_stride) * sizeof(Element);
const int64_t dim_offset_bytes = int64_t(k_loop) * 128 * sizeof(Element);
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + row_offset_bytes + dim_offset_bytes);
const int lds_offset_bytes = (min_tile_m * (kHeadDim / 128) + k_loop) * 16 * 128 * sizeof(Element);
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
}
}
}
flash::wait_buffer_data_arrived<true/*sync*/>(0);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int k_loop = 0; k_loop < kHeadDim / 128; ++k_loop) {
const int lds_offset_bytes = (min_tile_m * (kHeadDim / 128) + k_loop) * 16 * 128 * sizeof(Element);
const int q_lds_load_offset = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
DS_READ_MATRIX_64x16_B8(q_lds_load_offset, q_reg[min_tile_m][k_loop * 2 + 0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[min_tile_m][k_loop * 2 + 1].i32x4, true/*transpose*/)
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
}
__builtin_amdgcn_sched_barrier(0);
}
template<int kBlockM, int WARP_M, int M_MMAC_COUNT, typename Element>
__forceinline__ __device__ void fp8_mha_prefetch_q_to_vgpr_hdim192_gfx938(
vec4_uint q_addr,
Element* q_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][3],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset) {
static_assert(WARP_M == 32);
constexpr int kLoadBytes = 16 * 128 * sizeof(Element);
vec4_uint q_srsrc;
q_srsrc[1] = q_addr[1];
q_srsrc[2] = query_seqlen_stride;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
if (warp_id == min_tile_m) {
const int q_row_base = min_tile_m * 16;
const int valid_rows = max_seq_q_offset - q_row_base;
const int safe_q_row_base = valid_rows <= 0 ? 0 : q_row_base;
const int nm_filter = inline_min_max<0, 16>(16 - valid_rows);
q_srsrc[3] = valid_rows >= 16 ? 0 : (nm_filter << 8);
const int64_t row_offset_bytes = int64_t(safe_q_row_base) * int64_t(query_seqlen_stride) * sizeof(Element);
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + row_offset_bytes);
const int q_lds_first_offset = (min_tile_m * 2 + 0) * kLoadBytes;
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, q_lds_first_offset, 0);
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + row_offset_bytes + 64 * sizeof(Element));
const int q_lds_tail_offset = (min_tile_m * 2 + 1) * kLoadBytes;
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, q_lds_tail_offset, 0);
}
}
flash::wait_buffer_data_arrived<true/*sync*/>(0);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int q_lds_first_load = reinterpret_cast<size_t>(q_lds) + (min_tile_m * 2 + 0) * kLoadBytes;
DS_READ_MATRIX_64x16_B8(q_lds_first_load, q_reg[min_tile_m][0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(q_lds_first_load + 1024, q_reg[min_tile_m][1].i32x4, true/*transpose*/)
flash::wait_lds_data_arrived<true/*sync*/>(0);
const int q_lds_tail_load = reinterpret_cast<size_t>(q_lds) + (min_tile_m * 2 + 1) * kLoadBytes;
DS_READ_MATRIX_64x16_B8(q_lds_tail_load + 1024, q_reg[min_tile_m][2].i32x4, true/*transpose*/)
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
__builtin_amdgcn_sched_barrier(0);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_fp8_gfx938(const Params &params, const int bidb, const int bidh, const int warp_id) {
using Element = fp8_e4m3;
using ElementAccum = typename Kernel_traits::ElementAccum;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockN / WARP_N;
constexpr int kHeadDimVSplit = kHeadDimV / HEADDIM_V_SPLIT;
static_assert(kBlockK == 64);
static_assert(kHeadDim == 128 || kHeadDim == 256 || (kHeadDim == 192 && kHeadDimV == 128));
static_assert(kHeadDimVSplit == 128);
flash::SafeDecodeBlockInfo binfo;
binfo.set_params<Params, /*Is_Q_varlen=*/Is_Varlen, /*Is_K_Cumulative=*/false>(params, bidb);
int split_id = 0;
int original_actual_seqlen_k = binfo.actual_seqlen_k;
int partition_size = 0;
if constexpr (Split) {
split_id = blockIdx.y;
if constexpr (Is_Varlen) {
partition_size = splitkv_get_partitionsize_of_fix_numsplits(binfo.actual_seqlen_k, params.num_splits);
binfo.actual_seqlen_k = min(binfo.actual_seqlen_k - split_id * partition_size, partition_size);
} else {
partition_size = params.partition_size;
int num_splits = max(1, floor_div(binfo.actual_seqlen_k, partition_size));
binfo.actual_seqlen_k = (split_id == num_splits - 1)
? binfo.actual_seqlen_k - split_id * partition_size : partition_size;
binfo.actual_seqlen_k = (split_id >= num_splits) ? 0 : binfo.actual_seqlen_k;
if (split_id >= num_splits) return;
}
}
int block_x = blockIdx.x;
const int m_block = block_x / HEADDIM_V_SPLIT;
const int headdim_split_id = block_x & (HEADDIM_V_SPLIT - 1);
int ngroups = 1;
int actual_seqlen_q = binfo.actual_seqlen_q;
if constexpr (Is_Varlen) {
ngroups = params.ngroups;
actual_seqlen_q = binfo.actual_seqlen_q * ngroups;
}
if (m_block * kBlockM >= actual_seqlen_q || binfo.actual_seqlen_k <= 0) return;
extern __shared__ Element fp8_smem[];
constexpr int q_smem_bytes = STAGES * kBlockM * kBlockK * sizeof(Element);
constexpr int kv_smem_bytes = STAGES * kBlockK * WARP_N * sizeof(Element) * WARP_NUM;
constexpr int gemm_smem_bytes = q_smem_bytes > kv_smem_bytes ? q_smem_bytes : kv_smem_bytes;
Element* q_lds = reinterpret_cast<Element*>(fp8_smem);
Element* k_lds = reinterpret_cast<Element*>(fp8_smem);
Element* v_lds = k_lds;
ElementAccum* acc_o_lds = reinterpret_cast<ElementAccum*>(fp8_smem);
ElementAccum* max_lds = reinterpret_cast<ElementAccum*>(
reinterpret_cast<char*>(fp8_smem) + gemm_smem_bytes);
const int query_seqlen_stride = params.q_row_stride;
const int kcache_seqlen_stride = params.k_row_stride;
const int vcache_seqlen_stride = params.v_row_stride;
int n_block_min = 0;
int n_block_max = ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_local) {
const int q_row_start = m_block * kBlockM;
const int q_row_end = min(actual_seqlen_q, (m_block + 1) * kBlockM) - 1;
const int logical_q = Is_Varlen ? actual_seqlen_q / ngroups : actual_seqlen_q;
const int logical_row_start = Is_Varlen ? q_row_start / ngroups : q_row_start;
const int logical_row_end = Is_Varlen ? q_row_end / ngroups : q_row_end;
const int split_seqlen_start = Split ? split_id * partition_size : 0;
const int local_left = max(0, logical_row_start + original_actual_seqlen_k - logical_q - params.window_size_left);
const int local_right = min(original_actual_seqlen_k, logical_row_end + original_actual_seqlen_k - logical_q + params.window_size_right + 1);
const int split_local_left = local_left - split_seqlen_start;
const int split_local_right = local_right - split_seqlen_start;
const int n_block_count = n_block_max;
const int raw_n_block_min = max(0, split_local_left / kBlockN);
const int raw_n_block_max = ceil_div(max(0, split_local_right), kBlockN);
n_block_min = min(max(raw_n_block_min, 0), max(0, n_block_count - 1));
n_block_max = min(max(raw_n_block_max, n_block_min + 1), n_block_count);
}
const int page_block_size = params.page_block_size;
int *block_table = params.block_table + bidb * params.block_table_batch_stride;
const int this_split_seqlen_start = Split ? split_id * partition_size : 0;
block_table = block_table + (Split ? ceil_div(this_split_seqlen_start, page_block_size) : 0);
const int block_table_idx = n_block_min * kBlockN / page_block_size;
const int block_table_offset = n_block_min * kBlockN - block_table_idx * page_block_size;
const int64_t row_offset_k = int64_t(block_table[block_table_idx]) * int64_t(params.k_batch_stride)
+ block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const int64_t row_offset_v = int64_t(block_table[block_table_idx]) * int64_t(params.v_batch_stride)
+ block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const int64_t row_offset_q = Is_Varlen
? binfo.sum_s_q * ngroups * int64_t(query_seqlen_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(query_seqlen_stride)
: bidb * int64_t(params.q_batch_stride) + bidh * params.q_head_stride + m_block * kBlockM * int64_t(query_seqlen_stride);
auto q_addr = prepare_for_buffer_load<kHeadDim, Element, false>(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q);
auto k_addr = prepare_for_buffer_load<kHeadDim, Element, false>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto v_addr = prepare_for_buffer_load<kHeadDimV, Element, false>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v + headdim_split_id * kHeadDimVSplit);
const ElementAccum q_descale = params.q_descale_ptr[0];
const ElementAccum k_descale = params.k_descale_ptr[0];
const ElementAccum v_descale = params.v_descale_ptr[0];
__float2 qk_descale = {q_descale * k_descale, q_descale * k_descale};
int row_offset_lse;
ElementAccum *scores_sum_ptr = nullptr;
ElementAccum *scores_max_ptr = nullptr;
ElementAccum *softmax_lse_ptr = nullptr;
if constexpr (Split) {
int row_offset_scores_split;
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.h * ngroups * params.total_q);
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse + row_offset_scores_split;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
row_offset_scores_split = split_id * (params.b * params.h * params.seqlen_q);
scores_sum_ptr = reinterpret_cast<ElementAccum*>(params.scores_sum_ptr) + row_offset_lse + row_offset_scores_split;
scores_max_ptr = reinterpret_cast<ElementAccum*>(params.scores_max_ptr) + row_offset_lse + row_offset_scores_split;
}
} else {
if constexpr (Is_Varlen) {
row_offset_lse = bidh * ngroups * params.total_q + binfo.sum_s_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
} else {
row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
softmax_lse_ptr = reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse;
}
}
constexpr int M_WARP_COUNT = WARP_M / 32;
constexpr int K_WARP_COUNT = kBlockK / 32;
constexpr int N_WARP_COUNT = WARP_N / 32;
constexpr int K_LOOP_COUNT = kHeadDimVSplit / kBlockK;
vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT];
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT];
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4];
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64];
attention_initialize<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
if constexpr (kHeadDim == 192 && kHeadDimV == 128) {
fp8_mha_prefetch_q_to_vgpr_hdim192_gfx938<kBlockM, WARP_M, M_MMAC_COUNT, Element>(
q_addr, q_lds, q_reg, warp_id, query_seqlen_stride, actual_seqlen_q - m_block * kBlockM);
} else {
fp8_mha_prefetch_q_to_vgpr_gfx938<kHeadDim, kBlockM, WARP_M, M_MMAC_COUNT, Element>(
q_addr, q_lds, q_reg, warp_id, query_seqlen_stride, actual_seqlen_q - m_block * kBlockM);
}
int n_block_loop = n_block_min;
constexpr bool PrefetchK = true;
if constexpr (PrefetchK) {
int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_min * kBlockN;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
}
for (; n_block_loop < n_block_max; ++n_block_loop) {
const int warp_offset_in_seqkv = n_block_loop * kBlockN + warp_id * WARP_N;
const int warp_seqkv_limit = binfo.actual_seqlen_k - n_block_loop * kBlockN;
constexpr bool PrefetchVInQK = (kHeadDim == 128 && K_LOOP_COUNT == 2);
if constexpr (!PrefetchK) {
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, Element>(k_addr, k_lds, warp_id, kcache_seqlen_stride, warp_seqkv_limit);
}
vec4_Accum<ElementAccum> s_reg[M_WARP_COUNT * N_WARP_COUNT][4];
fp8_kvcache_qk_gemm_gfx938<PrefetchVInQK, K_LOOP_COUNT, kHeadDim, kBlockK, WARP_M, WARP_N, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
k_addr, v_addr, k_lds, v_lds, q_reg, s_reg, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit);
if constexpr (!PrefetchVInQK) {
fp8_kvcache_prefetch_v_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element>(
v_addr, v_lds, warp_id, vcache_seqlen_stride, warp_seqkv_limit);
}
fp8_kvcache_apply_descale_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(s_reg, qk_descale);
if constexpr (Is_causal) {
if constexpr (Is_Varlen) {
if constexpr (Is_local) {
fp8_kvcache_apply_mask_local_causal_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, ngroups, params.window_size_left, params.window_size_right);
} else {
kvcache_apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, ngroups);
}
} else {
kvcache_apply_mask_causal_gfx938_mtp<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(
s_reg, warp_offset_in_seqkv + this_split_seqlen_start, original_actual_seqlen_k, m_block * kBlockM, actual_seqlen_q, params.mtp, params.layout);
}
} else {
kvcache_apply_mask_gfx938<vec4_Accum<ElementAccum>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(s_reg, warp_seqkv_limit, warp_id * WARP_N);
}
mla_softmax_rescale_o<Is_causal || Is_local, ElementAccum, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, N_WARP_COUNT, WARP_NUM, M_MMAC_COUNT>(
s_reg, scores_max, scores_sum, acc_o, max_lds, warp_id, params.scale_softmax_log2);
union_vec32_fp8 p_reg[M_MMAC_COUNT];
fp8_kvcache_cvt_f32_to_fp8_gfx938<M_MMAC_COUNT, Element, ElementAccum>(p_reg, s_reg);
const int block_table_idx_cur = n_block_loop * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block_loop * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = min(n_block_max - 1, n_block_loop + 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
const int64_t k_addr_offset = (int64_t(table_diff) * int64_t(params.k_batch_stride) + offset_diff * int64_t(params.k_row_stride)) * sizeof(Element);
fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938<PrefetchK, K_LOOP_COUNT, kBlockK, kBlockN, M_WARP_COUNT, K_WARP_COUNT, WARP_NUM, M_MMAC_COUNT, Element, ElementAccum>(
v_addr, k_addr, v_lds, k_lds, p_reg, acc_o, warp_id, kcache_seqlen_stride, vcache_seqlen_stride, warp_seqkv_limit, k_addr_offset);
*(int64_t*)&v_addr += (int64_t(table_diff) * int64_t(params.v_batch_stride) + offset_diff * int64_t(params.v_row_stride)) * sizeof(Element);
}
if constexpr (PrefetchK) {
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
flash::wait_lds_data_arrived<true/*sync*/>(0);
const int thread_id = threadIdx.x;
const int lane_id = thread_id & 63;
if constexpr (WARP_NUM > 1) {
fp8_kvcache_acco_reduce_compact_gfx938<REUSE_KV_TIMES, K_LOOP_COUNT, K_WARP_COUNT, M_WARP_COUNT, M_MMAC_COUNT, WARP_NUM, ElementAccum>(
acc_o, acc_o_lds, params.seqlen_q, warp_id, lane_id);
}
if (params.s_aux_ptr != nullptr && split_id == 0) {
fp8_kvcache_apply_attention_sink_gfx938<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
acc_o, scores_max, scores_sum, params.s_aux_ptr, params.s_aux_type,
bidh, params.h, ngroups, m_block, kBlockM, lane_id, params.scale_softmax);
}
fp8_kvcache_epilogue_rescale_acco_gfx938<K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(acc_o, scores_sum, v_descale);
if constexpr (Is_Varlen) {
kvcache_epilogue_store_softmax_lse<Is_Varlen, true, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, softmax_lse_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM, params.total_q, params.ngroups);
const int64_t row_offset_o = binfo.sum_s_q * ngroups * int64_t(params.o_row_stride) + bidh * ngroups * params.o_head_stride + headdim_split_id * kHeadDimVSplit + m_block * kBlockM * int64_t(params.o_row_stride);
kvcache_varlen_epilogue_store_output_gfx938<Params, kHeadDimV, kHeadDimVSplit, Split, SplitkvAccumType, ElementAccum, kBlockM, kBlockK, WARP_NUM, K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT>(
acc_o, params, row_offset_o, actual_seqlen_q - m_block * kBlockM, bidb, bidh, m_block, split_id, headdim_split_id, warp_id, lane_id);
} else {
kvcache_epilogue_store_max_sum<Split, true/*Is_16x32*/, M_WARP_COUNT, M_MMAC_COUNT, ElementAccum>(
scores_max, scores_sum, scores_max_ptr, scores_sum_ptr, params.scale_softmax, warp_id, thread_id, lane_id, headdim_split_id, actual_seqlen_q - m_block * kBlockM);
kvcache_epilogue_store_output_gfx938<Params, kHeadDimV, kHeadDimVSplit, true/*alt*/, Split, SplitkvAccumType, ElementAccum, kBlockM, kBlockK, WARP_NUM, K_LOOP_COUNT, M_WARP_COUNT, K_WARP_COUNT, M_MMAC_COUNT>(
acc_o, params, bidb, bidh, m_block, split_id, headdim_split_id, warp_id, lane_id);
}
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_fp8_gfx938(const Params &params) {
#if defined(__gfx938__)
// The block index for the head.
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y; // batch x num_head, num_head first
// The block index for the batch.
const int bidb = Split ? blockIdx.z / params.h : blockIdx.z;
int warp_id_vec = threadIdx.x / 64; // warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
flash::compute_attn_1rowblock_splitkv_fp8_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
#endif
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
inline __device__ void compute_attn_splitkv_fp8_gfx938_hdim192_v128(const Params &params) {
#if defined(__gfx938__)
static_assert(Kernel_traits::kHeadDim == 192 && Kernel_traits::kHeadDimV == 128);
static_assert(HEADDIM_V_SPLIT == 1);
const int bidh = Split ? blockIdx.z % params.h : blockIdx.y;
const int bidb = Split ? blockIdx.z / params.h : blockIdx.z;
int warp_id_vec = threadIdx.x / 64;
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
flash::compute_attn_1rowblock_splitkv_fp8_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size * 128, Params>(params, bidb, bidh, warp_id);
#endif
}
} // namespace flash
......@@ -32,6 +32,19 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_kernel_gfx938(Flash_fwd_para
flash::compute_attn_gfx938<Kernel_traits, true, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, bool Is_GQA, int Layout>
__global__ void __launch_bounds__(256, 1) flash_fp8_fwd_kernel_gfx938(Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_fp8_attn_gfx938<Kernel_traits, true, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Is_GQA, Layout>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_Varlen, bool Return_softmax, bool Has_alibi, int Layout>
__global__ void __launch_bounds__(256, 1) flash_fwd_kernel_gfx92a(Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local));
flash::compute_attn_gfx92a<Kernel_traits, true, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_Varlen, Return_softmax, Has_alibi, Layout>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, hipStream_t stream) {
......@@ -130,6 +143,101 @@ void run_flash_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream) {
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fp8_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream) {
auto& instance = DeviceProperties<Kernel_traits, FAFUNC::FORWARD, true/*MLS_Enabled*/>::GetInstance();
params.cu_count = instance.cu_count;
const bool is_gqa = params.h != params.h_k;
const bool is_swa = ((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal);
int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h * Kernel_traits::SplitD, params.b);
if (Is_causal) {
if (is_gqa) {
grid = dim3(params.h * Kernel_traits::SplitD, params.b, num_m_block);
} else {
grid.x = (params.seqlen_q + 2 * Kernel_traits::kBlockM - 1) / (2 * Kernel_traits::kBlockM);
}
}
const bool is_varlen = params.cu_seqlens_q != nullptr && params.cu_seqlens_k != nullptr;
const bool is_even_MN = params.seqlen_k % Kernel_traits::kBlockN == 0
&& params.seqlen_q % Kernel_traits::kBlockM == 0
&& (!is_varlen || params.b == 1);
const bool has_alibi = (params.alibi_slopes_ptr not_eq nullptr);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_varlen, Is_Varlen, [&] {
BOOL_SWITCH(is_gqa, Is_GQA, [&] {
constexpr int IsEvenKConst = true;
BOOL_SWITCH(is_swa, Is_local, [&] {
BOOL_SWITCH(has_alibi, Has_Alibi, [&]{
constexpr bool ReturnSoftmaxConst = false;
LAYOUT_SWITCH(params.layout, [&]{
auto kernel = &flash_fp8_fwd_kernel_gfx938<Kernel_traits, Is_dropout, Is_causal, Is_local&&!Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 256, true/*Is_even_K*/, Is_Varlen, ReturnSoftmaxConst && Is_dropout, Has_Alibi, Is_GQA, Layout>;
kernel<<<grid, Kernel_traits::kNThreads, 16 * 1024, stream>>>(params);
});
});
});
});
});
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_gfx92a(Flash_fwd_params &params, hipStream_t stream) {
auto& instance = DeviceProperties<Kernel_traits, FAFUNC::FORWARD, true/*MLS_Enabled*/>::GetInstance();
params.cu_count = instance.cu_count;
size_t smem_size = instance.lds_size;
const char* fa_debug = std::getenv("FA_DEBUG");
const bool do_fa_debug = fa_debug != nullptr;
if (do_fa_debug) {
printf("[gfx92a launch] gcn_arch=%d cu_count=%d smem_size=%zu q_smem=%zu k_smem=%zu v_smem=%zu seqlen_q=%d seqlen_k=%d h=%d b=%d causal=%d layout=%d\n",
instance.gcn_arch,
params.cu_count,
smem_size,
Kernel_traits::q_smem_size,
Kernel_traits::k_smem_size,
Kernel_traits::v_smem_size,
params.seqlen_q,
params.seqlen_k,
params.h,
params.b,
static_cast<int>(Is_causal),
params.layout);
}
const bool is_swa = ((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal);
int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h * Kernel_traits::SplitD, params.b);
if constexpr (Is_causal) {
grid = dim3(params.h * Kernel_traits::SplitD, params.b, num_m_block);
}
const bool is_varlen = params.cu_seqlens_q != nullptr && params.cu_seqlens_k != nullptr;
const bool is_even_MN = !is_varlen && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool has_alibi = (params.alibi_slopes_ptr not_eq nullptr);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
constexpr int IsEvenKConst = true;
BOOL_SWITCH(is_varlen, Is_Varlen, [&] {
BOOL_SWITCH(is_swa, Is_local, [&] {
BOOL_SWITCH(has_alibi, Has_Alibi, [&]{
constexpr bool ReturnSoftmaxConst = false;
LAYOUT_SWITCH(params.layout, [&]{
auto kernel = &flash_fwd_kernel_gfx92a<Kernel_traits, Is_dropout, Is_causal, Is_local&&!Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 256, true/*Is_even_K*/, Is_Varlen, ReturnSoftmaxConst && Is_dropout, Has_Alibi, Layout>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
});
});
}
template<typename T>
......@@ -210,7 +318,9 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, hipStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// if arch >= 938, new MLS is allowed
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("FA_FWD_NO_MLS") == nullptr) {
if (gcn_arch == 930 and std::getenv("FA_FWD_NO_MLS") == nullptr) {
run_flash_fwd_gfx92a<Flash_fwd_kernel_traits<Headdim, Headdim, 128, 128, 32, 32, 32, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else if (gcn_arch >= 938 and std::getenv("FA_FWD_NO_MLS") == nullptr) {
if (params.qkvheaddim_compute == 96) {
if (params.qkvheaddim_tail_tile16 == 1)
run_flash_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, Headdim, 128, 128, 32, 32, 32, 2, false, false, T, T, T, 64, 96, 96, 1>, Is_dropout, Is_causal>(params, stream);
......@@ -234,6 +344,20 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, hipStream_t stream) {
}
}
template<typename T>
void run_fp8_mha_fwd_hdim128(Flash_fwd_params &params, hipStream_t stream) {
constexpr int Headdim = 128;
constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
int gcn_arch = getArch();
if (gcn_arch >= 938) {
run_flash_fp8_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, Headdim, 128, 128, 32, 32, 32, 2, false, false, T, Float16, fp8_e4m3>, Is_dropout, Is_causal>(params, stream);
} else {
printf("\x1b[31mfp8 is not supported in this arch!\033[0m\n");
}
});
}
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, hipStream_t stream) {
......@@ -412,7 +536,7 @@ void run_flash_fwd_prefix_prefill_launcher(Flash_fwd_params &params, hipStream_t
void (*kernel)(Flash_fwd_params);
constexpr bool IsEvenMNConst = false;
BOOL_SWITCH(params.window_size_left > 0 and params.window_size_right >= 0, Is_local, [&]{
BOOL_SWITCH(!Is_causal && params.window_size_left > 0 and params.window_size_right >= 0, Is_local, [&]{
kernel = &flash_fwd_prefix_prefill_kernel<Kernel_traits, false/*dropout*/, Is_causal, Is_local, IsEvenMNConst, false/*return softmax*/, false/*Has_Alibi*/, false/*Is_GQA*/, 1/*layout*/, Flash_fwd_params>;
});
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
......@@ -428,7 +552,24 @@ void run_flash_fwd_prefix_prefill_gfx938_launcher(Flash_fwd_params &params, hipS
dim3 grid(params.h, params.b, num_m_block);
constexpr bool IsEvenMNConst = false;
auto kernel = &flash_fwd_prefix_prefill_gfx938_kernel<Kernel_traits, false, false/*dropout*/, Is_causal, false/*Is_local*/, IsEvenMNConst, true, false/*return softmax*/, false/*Has_Alibi*/, 1/*layout*/, Flash_fwd_params>;
const bool is_local = !Is_causal && params.window_size_left > 0 && params.window_size_right >= 0;
BOOL_SWITCH(is_local, Is_local, [&] {
auto kernel = &flash_fwd_prefix_prefill_gfx938_kernel<Kernel_traits, false, false/*dropout*/, Is_causal, Is_local, IsEvenMNConst, true, false/*return softmax*/, false/*Has_Alibi*/, 1/*layout*/, Flash_fwd_params>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_fwd_prefix_prefill_gfx92a_launcher(Flash_fwd_params &params, hipStream_t stream) {
auto& instance = DeviceProperties<Kernel_traits, FAFUNC::FORWARD, true/*MLS_enabled*/>::GetInstance();
size_t smem_size = instance.lds_size;
int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(params.h, params.b, num_m_block);
constexpr bool IsEvenMNConst = false;
auto kernel = &flash_fwd_prefix_prefill_gfx92a_kernel<Kernel_traits, false, false/*dropout*/, Is_causal, false/*Is_local*/, IsEvenMNConst, true, false/*return softmax*/, false/*Has_Alibi*/, 1/*layout*/, Flash_fwd_params>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
}
......@@ -437,9 +578,18 @@ template<typename T, int Headdim, int HeaddimV>
void run_flash_fwd_prefix_prefill(Flash_fwd_params &params, hipStream_t stream) {
// is_causal = false, used in cascade attention
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (getArch() >= 938 and std::getenv("FA_FWD_NO_MLS") == nullptr and ((Headdim == 128 and HeaddimV == 128) or (Headdim == 192 and HeaddimV == 128))) {
const bool use_mls_prefix = std::getenv("FA_FWD_NO_MLS") == nullptr && ((Headdim == 128 and HeaddimV == 128) or (Headdim == 192 and HeaddimV == 128));
const int gcn_arch = getArch();
if (gcn_arch == 930 and use_mls_prefix) {
if constexpr (Headdim == 192)
run_flash_fwd_prefix_prefill_gfx92a_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 64, 32, 32, 32, 2, false, false, T>, Is_causal>(params, stream);
else
run_flash_fwd_prefix_prefill_gfx92a_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T>, Is_causal>(params, stream);
} else if (gcn_arch >= 938 and use_mls_prefix) {
if constexpr (Headdim == 192)
run_flash_fwd_prefix_prefill_gfx938_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 64, 32, 32, 32, 2, false, false, T>, Is_causal>(params, stream);
else if (params.page_block_size == 64)
run_flash_fwd_prefix_prefill_gfx938_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 64, 32, 32, 32, 2, false, false, T>, Is_causal>(params, stream);
else
run_flash_fwd_prefix_prefill_gfx938_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T>, Is_causal>(params, stream);
} else {
......@@ -486,3 +636,49 @@ void run_int8_flash_fwd_prefix_prefill(Flash_fwd_params &params, hipStream_t str
run_int8_flash_fwd_prefix_prefill_launcher<Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T, Float16, int8_t>, Is_causal>(params, stream);
});
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_fp8_fwd_prefix_prefill_launcher_gfx938(Flash_fwd_params &params, hipStream_t stream) {
constexpr bool NeedsWideFp8MlsLds = Kernel_traits::kHeadDim > 128 || Kernel_traits::kHeadDimV > 128;
size_t smem_size = NeedsWideFp8MlsLds ? 32 * 1024 : 16 * 1024;
int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid = Is_causal ? dim3(params.h, params.b, num_m_block)
: dim3(num_m_block, params.h, params.b);
constexpr bool Has_Alibi = false;
const bool is_local = !Is_causal && params.window_size_left > 0 && params.window_size_right >= 0;
BOOL_SWITCH(params.softmax_lse_ptr != nullptr, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(is_local, IsLocalConst, [&] {
LAYOUT_SWITCH(params.layout, [&]{
auto kernel = &flash_fp8_fwd_prefix_prefill_kernel_gfx938<Kernel_traits, true/*Is_training*/, false/*Is_dropout*/, Is_causal, IsLocalConst, true/*Is_even_K*/, ReturnSoftmaxConst, Has_Alibi, Layout>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
}
template<typename T, int Headdim, int HeaddimV>
void run_fp8_flash_fwd_prefix_prefill(Flash_fwd_params &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.page_block_size == 64) {
if (params.seqlen_q <= 128) {
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938<
Flash_fwd_kernel_traits<Headdim, HeaddimV, 64, 64, 32, 32, 32, 2, false, false, T, Float16, fp8_e4m3>, Is_causal>(params, stream);
} else {
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938<
Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 64, 32, 32, 32, 2, false, false, T, Float16, fp8_e4m3>, Is_causal>(params, stream);
}
} else {
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938<
Flash_fwd_kernel_traits<Headdim, HeaddimV, 128, 128, 32, 32, 32, 2, false, false, T, Float16, fp8_e4m3>, Is_causal>(params, stream);
}
});
} else {
printf("\x1b[31mfp8 prefix_prefill is not supported in this arch!\033[0m\n");
}
}
......@@ -9,6 +9,15 @@
#include "flash_fwd_kernel.h"
#include "flash_singleton.h"
#include "assert.h"
#include <string>
static inline bool hg_pa_is_gfx92a(const std::string &gcn_arch_name) {
return gcn_arch_name.rfind("gfx92a", 0) == 0;
}
static inline int hg_pa_runtime_gfx_arch_id(const std::string &gcn_arch_name) {
return hg_pa_is_gfx92a(gcn_arch_name) ? 930 : std::stoi(gcn_arch_name.substr(3, 3));
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_Varlen, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, bool Is_GQA, bool Is_softcap, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, bool Append_KV>
......@@ -23,18 +32,32 @@ __global__ void __launch_bounds__(256,1) flash_fwd_splitkv_int8_kernel(Flash_fwd
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_tile16x32_kernel(Params params) {
flash::compute_attn_splitkv_tile16x32<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
flash::compute_attn_splitkv_tile16x32<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
flash::compute_attn_splitkv_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_fp8_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_fp8_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel(Params params) {
flash::compute_attn_splitkv_fp8_gfx938_hdim192_v128<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Is_monopolize, bool Split, int M_MMAC_COUNT, int HEADDIM_V_SPLIT, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_gfx92a_kernel(Params params) {
flash::compute_attn_splitkv_gfx92a<Kernel_traits, Is_causal, Is_Varlen, Is_monopolize, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT>(params);
}
template<typename Kernel_traits, const bool Tail, typename Params>
void run_splitkv_reduce(Params &params, hipStream_t stream) {
......@@ -157,7 +180,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, hipStream_t stream) {
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
int gcn_arch = hg_pa_runtime_gfx_arch_id(gcn_arch_name);
#endif
const size_t smem_size = gcn_arch > 928 ? required_smem_size: size_t(64 * 1024);
if (std::getenv("FA_DEBUG") != nullptr) {
......@@ -224,7 +247,7 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params &params, hipStream_t strea
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
int gcn_arch = hg_pa_runtime_gfx_arch_id(gcn_arch_name);
#endif
const size_t smem_size = gcn_arch > 928 ? size_t(std::max<size_t>(32 * 1024, required_smem_size)): size_t(64 * 1024);
if (std::getenv("FA_DEBUG") != nullptr) {
......@@ -245,14 +268,16 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params &params, hipStream_t strea
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.window_size_left > 0 and params.window_size_right >= 0) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, false, true/*Is_local*/, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, 0>;
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, true/*Is_local*/, M_MMAC_COUNT, HEADDIM_V_SPLIT, 0>;
});
});
} else if (params.mtp == 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
......@@ -261,7 +286,9 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params &params, hipStream_t strea
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, Is_causal, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
});
......@@ -311,29 +338,157 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream)
const size_t required_smem_size = std::max(smem_for_acc, std::max(smem_for_gemm, smem_for_max));
const size_t smem_size = size_t(std::max<size_t>(32 * 1024, required_smem_size));
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld\n",
smem_for_max, smem_for_acc, q_smem_size, k_smem_size, v_smem_size, smem_for_gemm, required_smem_size, smem_size);
printf("grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
}
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = Kernel_traits::kHeadDimV == 256 ? 2 : 1;
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
BOOL_SWITCH(params.mtp != 1, Is_causal, [&]{
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
kernel = &flash_fwd_splitkv_gfx938_kernel<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
});
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce PA v2
if (params.q_batch_stride == 0) {
run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
} else {
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd_gfx92a(Flash_fwd_params &params, hipStream_t stream) {
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// decide shared memory
size_t smem_size = 32768;
if (grid.x * grid.y * grid.z <= params.cu_count) { smem_size = 65536; }
if (std::getenv("PA_NO_ALL_LDS") != nullptr) { smem_size = 32768; }
if (std::getenv("PA_USE_ALL_LDS") != nullptr) { smem_size = 65536; }
// output some details
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_size: %ld\n", smem_size);
printf("grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
}
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
constexpr int HEADDIM_V_SPLIT = Kernel_traits::kHeadDimV == 256 ? 2 : 1;
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.mtp == 1) {
BOOL_SWITCH(params.mtp != 1, Is_causal, [&] {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(smem_size == 65536, Is_monopolize, [&]{
kernel = &flash_fwd_splitkv_gfx92a_kernel<Kernel_traits, Is_causal, Is_Varlen, Is_monopolize, Split, M_MMAC_COUNT, HEADDIM_V_SPLIT>;
});
});
});
});
});
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce PA v2
if (params.q_batch_stride == 0) {
run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
} else {
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
}
template<typename Kernel_traits>
void run_fp8_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream) {
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
constexpr int kReduceBlockK = 32;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
const size_t smem_for_acc = Kernel_traits::kBlockM * WARP_NUM * kReduceBlockK * sizeof(float);
const size_t q_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockM * Kernel_traits::kBlockK * sizeof(Float8_e4m3_t);
const size_t k_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockK * Kernel_traits::kWaveN * sizeof(Float8_e4m3_t) * WARP_NUM;
const size_t v_smem_size = k_smem_size;
const size_t smem_for_gemm = std::max(q_smem_size, std::max(k_smem_size, v_smem_size));
constexpr bool IsFp8PA192x128 = Kernel_traits::kHeadDim == 192 && Kernel_traits::kHeadDimV == 128;
const size_t required_smem_size = IsFp8PA192x128
? std::max(smem_for_acc, std::max(smem_for_gemm, smem_for_max))
: std::max(smem_for_acc, smem_for_gemm + smem_for_max);
const size_t smem_size_floor = IsFp8PA192x128 ? size_t(32 * 1024) : size_t(17 * 1024);
const size_t smem_size = size_t(std::max(smem_size_floor, required_smem_size));
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld\n",
smem_for_max, smem_for_acc, q_smem_size, k_smem_size, v_smem_size, smem_for_gemm, required_smem_size, smem_size);
}
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = Kernel_traits::kHeadDimV == 256 ? 2 : 1;
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.window_size_left > 0 && params.window_size_right >= 0) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
if constexpr (IsFp8PA192x128) {
kernel = &flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, true/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
} else {
kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, true/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
}
});
});
});
} else if (params.mtp == 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_gfx938_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
constexpr bool Is_local = false;
if constexpr (IsFp8PA192x128) {
kernel = &flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
} else {
kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
}
});
});
});
......@@ -342,7 +497,11 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream)
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
if constexpr (IsFp8PA192x128) {
kernel = &flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
} else {
kernel = &flash_fwd_splitkv_fp8_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
}
});
});
});
......@@ -366,8 +525,34 @@ template<typename T, int Headdim, int HeaddimV>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream) {
// decide whether commonly used headdims
const bool is_commonly_used = params.d % 64 == 0 and params.d_value % 64 == 0/*prefetch 2 32x32 blocks along headdim*/;
// For latest archs, mls can be applied for headdim 128
if ((getArch() >= 938) and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used) {
// For latest archs, MLS can be applied for the common decode head dims.
int arch_id = getArch();
constexpr bool use_gfx938_mls =
(Headdim == 128 and HeaddimV == 128) or
(Headdim == 192 and HeaddimV == 128) or
(Headdim == 256 and HeaddimV == 256);
if constexpr (use_gfx938_mls) {
const bool is_local = params.window_size_left > 0 && params.window_size_right >= 0;
const bool use_mls_mask = params.is_e4m3 ? true : params.is_causal;
if ((arch_id >= 938) and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used and use_mls_mask) {
if (params.is_e4m3) {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
run_fp8_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 64, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
if (params.page_block_size == 64) {
constexpr int kBlockN = 64;
run_fp8_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 64, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
} else {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_fp8_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 64, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
}
#endif
} else {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
......@@ -375,13 +560,26 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream)
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
if (params.page_block_size == 64) {
constexpr int kBlockN = 64;
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
} else {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
}
#endif
}
return;
} else if (arch_id == 930 and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used) {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fwd_gfx92a<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
return;
}
}
// For MHA-fma, headdim = 128
else if (params.seqlen_q == 1 and !params.seqlenq_ngroups_swapped and Headdim == 128 and HeaddimV == 128 and std::getenv("PA_USE_FMA") != nullptr) {
if (params.seqlen_q == 1 and !params.seqlenq_ngroups_swapped and Headdim == 128 and HeaddimV == 128 and std::getenv("PA_USE_FMA") != nullptr) {
constexpr int kBlockN = 128;
run_flash_splitkv_fwd_mha<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32/*kBlockM*/, kBlockN, 32/*kBlockK*/, 32, 32, 2/*STAGES*/, false, false, T, float> >(params, stream);
}
......@@ -393,9 +591,14 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream)
run_flash_splitkv_fwd_tile16x32<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
if (params.page_block_size == 64) {
constexpr int kBlockN = 64;
run_flash_splitkv_fwd_tile16x32<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
} else {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fwd_tile16x32<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
}
#endif
} else {
// Decide whether compile all page block sizes
......@@ -407,6 +610,15 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream)
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T> >(params, stream);
});
#else
if (params.page_block_size == 64) {
constexpr int kBlockN = 64;
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
if (params.splitkv_use_fp32_as_accum) {
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, float> >(params, stream);
} else {
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T> >(params, stream);
}
} else {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
......@@ -415,6 +627,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream)
} else {
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T> >(params, stream);
}
}
#endif
}
}
......@@ -446,7 +659,7 @@ void run_int8_flash_splitkv_fwd(Flash_fwd_params &params, hipStream_t stream) {
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
int gcn_arch = hg_pa_runtime_gfx_arch_id(gcn_arch_name);
#endif
const size_t smem_size = gcn_arch > 928 ? required_smem_size: size_t(48 * 1024);
// printf("smem_for_max: %ld | smem_for_acc: %ld | smem_for_gemm: %ld | needed smem_size: %ld | smem_size: %ld\n", smem_for_max, smem_for_acc, smem_for_gemm, required_smem_size, smem_size);
......
......@@ -102,7 +102,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
......@@ -289,7 +289,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
tx_accum[t] = lds[tx * tx_float_count + t];
tx_accum[t + 1] = lds[tx * tx_float_count + t + 1];
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
......@@ -445,7 +445,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
......@@ -588,7 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
}
// cvt
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
......
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_<BFloat16, 128, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_mha_fwd_hdim128<BFloat16>(params, stream);
#endif
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_<Float16, 128, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_mha_fwd_hdim128<Float16>(params, stream);
#endif
}
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 128, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<BFloat16, 128, 128>(params, stream);
#endif
}
\ No newline at end of file
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 128, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<Float16, 128, 128>(params, stream);
#endif
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment