Commit 5248d7d2 authored by hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -234,7 +234,7 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
......
......@@ -401,13 +401,11 @@ __forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
} else {
// gfx938 m_ab = 0的gemm想要复用m_ab = 1的LDS数据
......
......@@ -410,13 +410,11 @@ __forceinline__ __device__ void gemm_tt_kq_gfx938(
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
}
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
......
......@@ -117,13 +117,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
for(int m_loop = 0; m_loop < M / 128; ++m_loop) {
......@@ -147,13 +144,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
if(n_loop < N / 32) {
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
......@@ -167,36 +161,20 @@ inline __device__ void prefetch_to_vgpr_gfx938(
if(trans){
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
} else {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
}
lgkmcnt_wait(0);
......@@ -246,13 +224,11 @@ inline __device__ void prefetch_to_lds_gfx938(
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset);
//计算LDS地址,每个warp使用一个32*32;下一个loop重复利用
int lds_offset = (loop_warp * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset;
int lds_load_offset = reinterpret_cast<size_t>(lds) + lds_offset;
if (trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset, 0);
}
}
}
......
......@@ -327,7 +327,7 @@ inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_
auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp =
hcu_pk_fma_f32(
__builtin_hcu_pk_fma_f32(
vec2_tensor,
vec2_scale,
vec2_max_scaled);
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -386,6 +386,8 @@ template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_prefix_prefill_
template<typename T, int Headdim, int HeaddimV> void run_int8_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_fp8_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_prefix_prefill_dispatch_(Flash_fwd_mla_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -30,7 +30,7 @@ __forceinline__ __device__ void fwd_epilugue_rescale_acco(
const int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + ni * (WARP_M / 32) + mi;
#if defined(__gfx936__) || defined(__gfx938__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
......@@ -108,58 +108,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
#if defined(__gfx938__)
constexpr bool Is_Interleaved_ = Is_Interleaved and kHeadDimV == 128;
#else
constexpr bool Is_Interleaved_ = Is_Interleaved;
#endif
if constexpr (Is_Interleaved_) {
#if defined(__gfx938__)
#pragma unroll
for(int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
for(int warp_m_idx = 0; warp_m_idx < (WARP_M / 32); ++warp_m_idx) {
#pragma unroll
for(int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int tile32x32_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
int s_offset = k_loop * kBlockK;
int seqlen_q_offset = (warp_id * WARP_M + warp_m_idx * 32 + pv_lane_seq_idx * 2 + min_tile_m);
int v_offset = seqlen_q_offset * seqlen_o_stride + pv_lane_head_dim_idx * 8;
union_vec4_f16x2<Element> v_data;
#pragma unroll
for(int vec_index = 0; vec_index < 4; ++vec_index) {
constexpr bool is_bf16 = std::is_same<Element, bhalf_t>::value;
v_data.f16x2[vec_index][0] = DownCast<ElementAccum, Element, is_bf16>(acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index]);
v_data.f16x2[vec_index][1] = DownCast<ElementAccum, Element, is_bf16>(acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
}
auto lds = (__attribute__((address_space(3))) float*)(0);
int lds_write_offset = (warp_id * 512 + pv_lane_seq_idx * 16 + pv_lane_head_dim_idx * 4 + pv_lane_seq_idx * 4) * 4;
__builtin_amdgcn_sched_barrier(0);
inlineasm_ds_write_b128(lds_write_offset, v_data.f32);
flash::wait_lds_data_arrived<false>(0);
#pragma unroll
for(int vec_index = 0; vec_index < 2; ++vec_index) {
int lds_load_offset = (warp_id * 512 + pv_lane_seq_idx * 16 + vec_index * 8 + pv_lane_head_dim_idx + pv_lane_seq_idx * 4) * 4;
asm volatile("ds_read2_b32 %0, %1 offset0:0 offset1:%2\n":: "v"(v_data.data[vec_index]), "v"(lds_load_offset), "B"(4));
}
flash::wait_lds_data_arrived<false>(0);
if constexpr (Is_even_MN) {
*(vec4_fp32*)(o_ptr + v_offset + s_offset + k_tile_idx * 32) = v_data.f32;
} else {
if(m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(vec4_fp32*)(o_ptr + v_offset + s_offset + k_tile_idx * 32) = v_data.f32;
}
}
}
}
}
}
#else
if constexpr (Is_Interleaved) {
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
......@@ -173,11 +122,13 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
const int pv_tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
const int mmac_id = min_tile_m + min_tile_n * 2;
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * 32 + min_tile_m * 16 + pv_lane_seq_idx;
// prepare for store
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
......@@ -191,8 +142,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
}
}
}
#endif
} // brace, to control vgpr usage
} else {
auto gO = prepare_for_buffer_load<kHeadDimV, Element, TcpSwizzle>(o_ptr);
#pragma unroll
......
File mode changed from 100644 to 100755
#include "numeric_types.h"
#include "intrinsic.h"
/**
 * Epilogue rescale of the FP8 attention output accumulator.
 *
 * For every 16-row group `m_idx` handled by this wave:
 *   - when StoreLSE is set, writes lse[m_idx] = scores_max * softmax_scale + log(scores_sum),
 *     or INFINITY when the row sum is zero or NaN;
 *   - multiplies every accumulator tile of that row group by v_descale / scores_sum
 *     (or by 0 for a zero/NaN sum unless AssumeValidRows is set).
 *
 * Template parameters:
 *   AssumeValidRows  skip the zero/NaN guard on the row sum when computing the rescale factor.
 *   kHeadDim         head dimension; acc_o holds kHeadDim / 32 tiles along it.
 *   WARP_M, WARP_N   per-wave tile extents; acc_o holds WARP_M / 16 x WARP_N / 16 sub-tiles.
 *   StoreLSE         whether `lse` is written.
 *
 * Parameters:
 *   acc_o          accumulator tiles, rescaled in place.
 *   scores_max     per-row-group running maximum (only read when StoreLSE).
 *   scores_sum     per-row-group softmax denominator.
 *   lse            output log-sum-exp per row group (only written when StoreLSE).
 *   softmax_scale  factor applied to scores_max in the LSE expression.
 *   v_descale      numerator of the rescale factor applied to acc_o.
 */
template<bool AssumeValidRows, int kHeadDim, int WARP_M, int WARP_N, bool StoreLSE, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_rescale_acc_o(
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum lse[WARP_M / 16],
    ElementAccum softmax_scale,
    ElementAccum v_descale
) {
    // acc_o is declared with kHeadDim / 32 tiles along the head dimension, so the
    // rescale loop must visit exactly that many; any larger bound indexes past the array.
    constexpr int kHeadDimTiles = kHeadDim / 32;

    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        const ElementAccum sum = scores_sum[m_idx];
        // A row whose denominator is zero or NaN (sum != sum) has no valid softmax.
        [[maybe_unused]] const bool degenerate_row = (sum == 0.f || sum != sum);

        if constexpr (StoreLSE) {
            lse[m_idx] = degenerate_row ? INFINITY : __llvm_fma_f32(scores_max[m_idx], softmax_scale, __logf(sum));
        }

        ElementAccum total_rescale;
        if constexpr (AssumeValidRows) {
            total_rescale = v_descale / sum;
        } else {
            total_rescale = degenerate_row ? 0.f : v_descale / sum;
        }

        // Broadcast the scalar factor into both lanes of the packed multiply operand.
        __float2 total_scale_pair;
        total_scale_pair[0] = total_rescale;
        total_scale_pair[1] = total_rescale;

        #pragma unroll
        for (int k_loop = 0; k_loop < kHeadDimTiles; ++k_loop) {
            #pragma unroll
            for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
                // Each tile carries four fp32 values, scaled as two packed pairs.
                acc_o[k_loop][m_idx][n_idx].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[k_loop][m_idx][n_idx].u64[0], total_scale_pair);
                acc_o[k_loop][m_idx][n_idx].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[k_loop][m_idx][n_idx].u64[1], total_scale_pair);
            }
        }
    }
}
/**
 * Writes the per-row log-sum-exp values held by this wave to `softmax_lse_ptr`.
 *
 * Only lanes 0..15 store. Each of them writes one value per 16-row group, at row
 *   wave_row_offset + ((lane_id & 15) >> 2) * 8 + m_idx * 4 + (lane_id & 3)
 * relative to row_offset_lse, and only when that relative row is below actual_seqlen_q.
 *
 * Parameters:
 *   softmax_lse_ptr  destination buffer, indexed by absolute row id.
 *   scores_max       not referenced by this function.
 *   scores_sum       not referenced by this function.
 *   lse              values to store, one per 16-row group.
 *   row_offset_lse   base row of this (batch, head) slice; the original annotation
 *                    gives it as (bidb * h + bidh) * actual_seqlen_q.
 *   actual_seqlen_q  exclusive upper bound on the slice-relative row index.
 *   wave_row_offset  first row of this wave; the original annotation gives it as
 *                    m_block * kBlockM + warp_id * WARP_M.
 *   lane_id          lane index within the wave.
 *
 * Is_even_MN is not referenced; the bounds check is always performed.
 */
template<bool Is_even_MN, int WARP_M, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_store_lse(
    ElementAccum* softmax_lse_ptr,
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum lse[WARP_M / 16],
    int row_offset_lse,
    int actual_seqlen_q,
    int wave_row_offset,
    int lane_id
) {
    // Lanes 16 and above have nothing to store.
    if (!(lane_id < 16)) {
        return;
    }

    // Row contribution that depends only on the lane, hoisted out of the loop.
    const int lane_row = ((lane_id & 15) >> 2) * 8 + (lane_id & 3);

    #pragma unroll
    for (int group = 0; group < WARP_M / 16; ++group) {
        const int absolute_row = row_offset_lse + wave_row_offset + lane_row + group * 4;
        const int relative_row = absolute_row - row_offset_lse;
        if (relative_row < actual_seqlen_q) {
            softmax_lse_ptr[absolute_row] = lse[group];
        }
    }
}
/**
 * Down-converts the fp32 accumulator tiles to `Element` and stores them to `acc_o_ptr`.
 *
 * For each (m_idx, pv_loop, mmac_id) tile a lane converts its four accumulator values
 * (two pairs via DownCastPair) and writes them as one union_vec2_f16x2<Element> at
 *   row = warp_id * WARP_M + ((lane_id & 15) >> 2) * 8 + m_idx * 4 + (lane_id & 3)
 *   col = pv_loop * 32 + mmac_id * 16 + (lane_id >> 4) * 4
 * using `o_row_stride` elements between rows.
 *
 * When Is_even_MN is false the store is skipped for rows where
 * m_block * kBlockM + row >= actual_seqlen_q; when it is true no check is made.
 */
template<bool Is_even_MN, int kBlockM, int kHeadDim, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_store_output(
    Element* acc_o_ptr,
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    int m_block,
    int warp_id,
    int lane_id,
    int o_row_stride,
    int actual_seqlen_q
) {
    using PackedOut = union_vec2_f16x2<Element>;

    // Lane-dependent parts of the destination coordinates.
    const int lane_row = ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int lane_col = (lane_id >> 4) * 4;

    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        const int row_idx = warp_id * WARP_M + lane_row + m_idx * 4;

        // The row predicate does not depend on the inner loops.
        bool row_is_writable = true;
        if constexpr (!Is_even_MN) {
            row_is_writable = (m_block * kBlockM + row_idx < actual_seqlen_q);
        }

        #pragma unroll
        for (int pv_loop = 0; pv_loop < kHeadDim / 32; ++pv_loop) {
            #pragma unroll
            for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
                const int col_idx = pv_loop * 32 + mmac_id * 16 + lane_col;
                const int offset = row_idx * o_row_stride + col_idx;

                // Convert the two fp32 pairs of this tile to the output element type.
                PackedOut packed;
                #pragma unroll
                for (int pair = 0; pair < 2; ++pair) {
                    packed.f16x2[pair] = DownCastPair<ElementAccum, Element>(acc_o[pv_loop][m_idx][mmac_id].f32x2[pair]);
                }

                if (row_is_writable) {
                    *(PackedOut*)(acc_o_ptr + offset) = packed;
                }
            }
        }
    }
}
#include "fp8_qk_gemm_utils_mls_ds.h"
#include "static_switch.h"
// Non-prefetching variant of the P*V GEMM: K is not prefetched here, so the
// K-related arguments are accepted but not used.
//
// After waiting for outstanding LDS traffic (s_waitcnt lgkmcnt(0)), accumulates
//   acc_o[pv][m][n] = mmac_4interleave_fp8(p_reg[m].i8x8[k], v_regs[k][pv].i8x8[n], acc_o[pv][m][n])
// for k in [0, kBlockN / WARP_N), pv in [0, kHeadDim / 32), m in [0, 2), n in [0, 2).
//
// PrefetchK, Is_even_MN, v_lds, k_ptr, k_lds, warp_id, k_row_stride and
// max_seq_kv_offset are not referenced in the body.
template<bool PrefetchK, bool Is_even_MN, int kHeadDim, int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_pv_gemm_and_prefetch_k(
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    union_vec32_fp8 p_reg[WARP_M / 16],
    union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32],
    int8_t* v_lds,
    Element*& k_ptr,
    int8_t* k_lds,
    int warp_id,
    int k_row_stride,
    int max_seq_kv_offset
) {
    constexpr int kKvSteps = kBlockN / WARP_N;
    constexpr int kHeadTiles = kHeadDim / 32;
    constexpr int kRowTiles = 2;
    constexpr int kColTiles = 2;

    // Wait for the data requested from LDS to arrive; the scheduling barriers
    // keep the compiler from moving instructions across the wait.
    __builtin_amdgcn_sched_barrier(0);
    asm volatile("s_waitcnt lgkmcnt(0)\n");
    __builtin_amdgcn_sched_barrier(0);

    // MMAC instruction stream.
    #pragma unroll
    for (int kv = 0; kv < kKvSteps; kv += 1) {
        #pragma unroll
        for (int pv = 0; pv < kHeadTiles; ++pv) {
            #pragma unroll
            for (int row = 0; row < kRowTiles; ++row) {
                #pragma unroll
                for (int col = 0; col < kColTiles; ++col) {
                    auto& acc = acc_o[pv][row][col];
                    acc.f32 = mmac_4interleave_fp8<int8_t, ElementAccum>(
                        p_reg[row].i8x8[kv],
                        v_regs[kv][pv].i8x8[col],
                        acc.f32);
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Prefetching variant of the P*V GEMM (paged KV): the next K block is issued to
// LDS via fp8_prefetch_k_to_lds before the MMAC stream so the two can overlap.
//
// Sequence: wait for outstanding LDS traffic (s_waitcnt lgkmcnt(0)), issue the K
// prefetch for `k_ptr_next`, then accumulate
//   acc_o[pv][m][n] = mmac_4interleave_fp8(p_reg[m].i8x8[k], v_regs[k][pv].i8x8[n], acc_o[pv][m][n])
// for k in [0, kBlockN / WARP_N), pv in [0, kHeadDim / 32), m in [0, 2), n in [0, 2).
//
// v_lds is not referenced in the body.
template<bool Is_even_MN, int kHeadDim, int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_pv_gemm_and_prefetch_k_paged(
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    union_vec32_fp8 p_reg[WARP_M / 16],
    union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32],
    int8_t* v_lds,
    Element* k_ptr_next,
    int8_t* k_lds,
    int warp_id,
    int k_row_stride,
    int max_seq_kv_offset_next
) {
    constexpr int kKvSteps = kBlockN / WARP_N;
    constexpr int kHeadTiles = kHeadDim / 32;
    constexpr int kRowTiles = 2;
    constexpr int kColTiles = 2;

    // Wait for the data requested from LDS to arrive.
    __builtin_amdgcn_sched_barrier(0);
    asm volatile("s_waitcnt lgkmcnt(0)\n");
    __builtin_amdgcn_sched_barrier(0);

    // Prefetch the next K block into LDS (overlapping with the MMACs below).
    __builtin_amdgcn_sched_barrier(0);
    fp8_prefetch_k_to_lds<Is_even_MN, kHeadDim, WARP_N, Element>(
        k_ptr_next, k_lds, warp_id, k_row_stride, max_seq_kv_offset_next);
    __builtin_amdgcn_sched_barrier(0);

    // MMAC instruction stream.
    #pragma unroll
    for (int kv = 0; kv < kKvSteps; kv += 1) {
        #pragma unroll
        for (int pv = 0; pv < kHeadTiles; ++pv) {
            #pragma unroll
            for (int row = 0; row < kRowTiles; ++row) {
                #pragma unroll
                for (int col = 0; col < kColTiles; ++col) {
                    auto& acc = acc_o[pv][row][col];
                    acc.f32 = mmac_4interleave_fp8<int8_t, ElementAccum>(
                        p_reg[row].i8x8[kv],
                        v_regs[kv][pv].i8x8[col],
                        acc.f32);
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic_mls_ds.h"
// Issues two __builtin_hcu_matrix_load_128X16_b8 loads that bring this wave's
// share of V into LDS: "tile1" covers rows [warp_id*32, warp_id*32+16) and
// "tile2" covers rows [warp_id*32+16, warp_id*32+32).
//
// Parameters:
//   v_ptr              source pointer handed to prepare_for_matrix_load.
//   v_lds              LDS destination base; this wave writes at byte offset
//                      warp_id * WARP_N * kHeadDim * sizeof(Element).
//   warp_id            index of the calling wave.
//   v_row_stride       row stride written into descriptor word 2 and used (times
//                      sizeof(Element)) to advance the base address per tile.
//   max_seq_kv_offset  used to derive the per-tile filter value `nm_filter`.
//                      NOTE(review): presumably the number of valid KV rows left
//                      in this block -- confirm against the caller.
//
// Is_even_MN and kBlockN are not referenced in the body.
// NOTE(review): the exact meaning of descriptor word 3 (nm_filter << 8) and of the
// trailing boolean arguments of the load builtin is not visible here -- confirm
// against the intrinsic documentation before changing them.
template<bool Is_even_MN, int kBlockN, int kHeadDim, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_prefetch_v_to_lds(
Element* v_ptr,
int8_t* v_lds,
int warp_id,
int v_row_stride,
int max_seq_kv_offset
) {
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint v_root = prepare_for_matrix_load<kHeadDim, Element>(v_ptr);
vec4_uint v_srsrc;
v_srsrc[0] = v_root[0];
v_srsrc[1] = v_root[1];
v_srsrc[2] = v_row_stride; // stride
v_srsrc[3] = 0x00000;
// The 4 waves prefetch the whole block directly.
int v_lds_write_bytes = warp_id * WARP_N * kHeadDim * sizeof(Element);
// Each read fetches 32x128 of data.
// tile1: rows [warp_id*32, warp_id*32+16)
// When the whole tile is filtered out, keep one valid row of V so that
// 0 * NaN cannot contaminate the PV product.
int nm_filter_warp0_tile1 = inline_min_max<0, 16>(16 - max_seq_kv_offset);
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_kv_offset);
// A fully filtered tile (nm_filter == 16) falls back to the base address;
// otherwise only the low descriptor word is advanced to this wave's first tile.
v_srsrc[0] = (nm_filter == 16) ? v_root[0] : v_root[0] + (warp_id * 2) * 16 * v_row_stride * sizeof(Element);
// For the fallback, reuse warp 0's tile1 filter capped at 15.
nm_filter = (nm_filter == 16) ? min(nm_filter_warp0_tile1, 15) : nm_filter;
// No filter is applied when max_seq_kv_offset is a multiple of 128.
// v_srsrc[3] is still 0 here, so this addition just stores the filter bits.
v_srsrc[3] = v_srsrc[3] + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
flash::wait_all_warp_arrived();
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds+v_lds_write_bytes, 0, true, false, false, false, false);
// tile2: rows [warp_id*32+16, warp_id*32+32)
int nm_filter_warp0_tile2 = inline_min_max<0, 16>(32 - max_seq_kv_offset);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_kv_offset);
// Same fallback rule as tile1, with the base advanced by one more 16-row tile.
v_srsrc[0] = (nm_filter == 16) ? v_root[0] : v_root[0] + (warp_id * 2 + 1) * 16 * v_row_stride * sizeof(Element);
nm_filter = (nm_filter == 16) ? min(nm_filter_warp0_tile2, 15) : nm_filter;
// Word 3 is overwritten (not accumulated) so tile1's filter bits do not leak in.
v_srsrc[3] = ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
// The second tile lands (128 * 16 >> 1) past the first tile's LDS destination.
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds+v_lds_write_bytes + (128 * 16 >> 1), 0, true, false, false, false, false);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment