Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
......@@ -9,6 +9,7 @@ __attribute__((weak)) int getArch() {
auto hipResult = hipGetDeviceProperties(&props, 0);
std::string gcn_arch_name(props.gcnArchName);
gcn_arch_name = gcn_arch_name.substr(3, 3);
if (gcn_arch_name == "92a") gcn_arch_name = "930";
int gcn_arch = std::stoi(gcn_arch_name);
return gcn_arch;
}
......@@ -38,13 +39,8 @@ private:
DeviceProperties() { // 可以在这里给内部变量赋初始值
hipDeviceProp_t props;
auto hipResult = hipGetDeviceProperties(&props, 0);
#ifdef ROCM_5_7
this->gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
this->gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
#endif
this->cu_count = props.multiProcessorCount;
this->gcn_arch = getArch();
const char* fa_debug = std::getenv("FA_DEBUG");
bool do_fa_debug = fa_debug != nullptr;
......@@ -55,16 +51,19 @@ private:
const size_t q_smem_size = run_new_mls ? least_required_size: Kernel_traits::q_smem_size;
const size_t k_smem_size = run_new_mls ? least_required_size: Kernel_traits::k_smem_size * 2;
const size_t v_smem_size = run_new_mls ? least_required_size: Kernel_traits::v_smem_size * 2;
if (gcn_arch == 928 or gcn_arch == 936 or gcn_arch == 938) {
if (gcn_arch == 928 or gcn_arch == 936 or gcn_arch == 938 or gcn_arch == 946) {
this->lds_size = run_new_mls ? std::max(q_smem_size, std::max(v_smem_size, k_smem_size)): std::max(q_smem_size, v_smem_size + k_smem_size);
}
else if (gcn_arch == 930) {
this->lds_size = 32 * 1024;
}
if (do_fa_debug and std::strcmp(fa_debug, "2")) {
printf("gcn_arch: %d\nq_smem_size: %ld\nk_smem_size: %ld\nv_smem_size: %ld\nshared memory size: %ld\ncu count: %d\n", this->gcn_arch, q_smem_size, k_smem_size, v_smem_size, this->lds_size, this->cu_count);
}
} else if constexpr (Func == FAFUNC::BACKWARD) {
this->lds_size = 32 * 1024;
if(this->gcn_arch >= 936 && Kernel_traits::kHeadDim == 128){
if(this->gcn_arch == 936) {
if(this->gcn_arch >= 936 && Kernel_traits::kHeadDim <= 128){
if(this->gcn_arch == 936 || this->gcn_arch == 938) {
this->lds_size = 21 * 1024;
} else {
this->lds_size = 16 * 1024;
......
#include "numeric_types.h"
#include "intrinsic.h"
/**
 * Fold an attention-sink logit into the running online-softmax state.
 *
 * For every 32-row group `mi` and each of its two halves `min_tile_m`:
 *   - the stored row maximum is read from `scores_max` and multiplied by
 *     `scale_softmax`;
 *   - the new maximum is the larger of that value and `sink_value`;
 *   - `scores_sum` is multiplied by exp(old_max - new_max) and gains the sink
 *     term exp(sink_value - new_max);
 *   - `scores_max` is written back as new_max / scale_softmax;
 *   - every element of the matching `acc_o` tiles is multiplied by the same
 *     exp(old_max - new_max) factor.
 *
 * @tparam WARP_M       rows handled by this call (multiples of 32 are iterated).
 * @tparam kBlockK      tile width used to size and index `acc_o`.
 * @tparam kHeadDimV    V head dimension used to size and index `acc_o`.
 * @tparam ElementAccum accumulator scalar type.
 * @param acc_o         output accumulators, rescaled in place.
 * @param scores_max    running row maxima, updated in place.
 * @param scores_sum    running row sums, updated in place.
 * @param scale_softmax factor applied to `scores_max` on read and divided out on
 *                      write; the code divides by it, so it must be non-zero.
 * @param sink_value    sink logit, compared directly against the scaled maximum.
 *
 * NOTE(review): the tile index `ni * (WARP_M / 32) + mi` mirrors
 * fwd_epilugue_rescale_acco; it coincides with the
 * `warp_m_idx * (kBlockK / 32) + k_tile_idx` layout used by the store
 * epilogue only when WARP_M == 32 -- confirm before using a larger WARP_M.
 */
template<int WARP_M, int kBlockK, int kHeadDimV, typename ElementAccum>
__forceinline__ __device__ void fwd_apply_attention_sink(
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4],
    vec2_Accum<ElementAccum> scores_max[WARP_M / 32],
    vec2_Accum<ElementAccum> scores_sum[WARP_M / 32],
    const ElementAccum scale_softmax,
    const float sink_value) {
    #pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const ElementAccum old_scaled_max = scores_max[mi].f32[min_tile_m] * scale_softmax;
            const ElementAccum new_scaled_max = max(old_scaled_max, ElementAccum(sink_value));
            // Factor that re-bases everything accumulated under the old maximum.
            const ElementAccum old_rescale = __expf(old_scaled_max - new_scaled_max);
            scores_sum[mi].f32[min_tile_m] = scores_sum[mi].f32[min_tile_m] * old_rescale + __expf(ElementAccum(sink_value) - new_scaled_max);
            scores_max[mi].f32[min_tile_m] = new_scaled_max / scale_softmax;
            // Same architecture list as the packed path of fwd_epilugue_rescale_acco,
            // so both rescale helpers pick the packed multiply on the same targets.
            #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
            // Only the packed path needs the duplicated factor.
            __float2 old_rescale_pair = {old_rescale, old_rescale};
            #endif
            #pragma unroll
            for (int ni = 0; ni < (kBlockK / 32); ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int mmac_id = min_tile_n * 2 + min_tile_m;
                    #pragma unroll
                    for (int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
                        const int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + ni * (WARP_M / 32) + mi;
                        #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                        // Two packed multiplies cover the four accumulator lanes.
                        #pragma unroll
                        for (int vec_id = 0; vec_id < 2; ++vec_id) {
                            acc_o[pv_tile_id][mmac_id].u64[vec_id] =
                                __builtin_hcu_pk_mul_f32(acc_o[pv_tile_id][mmac_id].u64[vec_id], old_rescale_pair);
                        }
                        #else
                        // Scalar fallback: one multiply per accumulator lane.
                        #pragma unroll
                        for (int vec_id = 0; vec_id < 4; ++vec_id) {
                            acc_o[pv_tile_id][mmac_id].f32[vec_id] *= old_rescale;
                        }
                        #endif
                    }
                }
            }
        }
    }
}
template<int WARP_M, int kBlockK, int kHeadDimV, bool Is_dropout, typename ElementAccum>
__forceinline__ __device__ void fwd_epilugue_rescale_acco(
......@@ -28,9 +72,9 @@ __forceinline__ __device__ void fwd_epilugue_rescale_acco(
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
const int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + ni * (WARP_M / 32) + mi;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
......@@ -108,14 +152,51 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
constexpr bool Is_Interleaved_ = Is_Interleaved and kHeadDimV == 128;
#else
constexpr bool Is_Interleaved_ = Is_Interleaved;
#endif
if constexpr (Is_Interleaved_) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < (WARP_M / 32); ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int pv_tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
const int mmac_id = min_tile_m + min_tile_n * 2;
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * 32 + min_tile_m * 16 + pv_lane_seq_idx;
// prepare for store
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
}
}
}
}
} // brace, to control vgpr usage
#else
// simulate mmac-4interleave via lds
// todo: lds bank conflicts, vgpr spills
#pragma unroll
for(int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
......@@ -128,14 +209,17 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
int s_offset = k_loop * kBlockK;
int seqlen_q_offset = (warp_id * WARP_M + warp_m_idx * 32 + pv_lane_seq_idx * 2 + min_tile_m);
int v_offset = seqlen_q_offset * seqlen_o_stride + pv_lane_head_dim_idx * 8;
// prepare vgprs
union_vec4_f16x2<Element> v_data;
#pragma unroll
for(int vec_index = 0; vec_index < 4; ++vec_index) {
// convert float -> bf16/fp16
constexpr bool is_bf16 = std::is_same<Element, bhalf_t>::value;
v_data.f16x2[vec_index][0] = DownCast<ElementAccum, Element, is_bf16>(acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index]);
v_data.f16x2[vec_index][1] = DownCast<ElementAccum, Element, is_bf16>(acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
}
// try interleave
auto lds = (__attribute__((address_space(3))) float*)(0);
int lds_write_offset = (warp_id * 512 + pv_lane_seq_idx * 16 + pv_lane_head_dim_idx * 4 + pv_lane_seq_idx * 4) * 4;
__builtin_amdgcn_sched_barrier(0);
......@@ -148,6 +232,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
flash::wait_lds_data_arrived<false>(0);
// write to global memory
if constexpr (Is_even_MN) {
*(vec4_fp32*)(o_ptr + v_offset + s_offset + k_tile_idx * 32) = v_data.f32;
} else {
......@@ -159,42 +244,9 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
}
}
#else
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < (WARP_M / 32); ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int pv_tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
const int mmac_id = min_tile_m + min_tile_n * 2;
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * 32 + min_tile_m * 16 + pv_lane_seq_idx;
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
}
}
}
}
}
#endif
} else {
auto gO = prepare_for_buffer_load<kHeadDimV, Element, TcpSwizzle>(o_ptr);
auto o_resource = prepare_for_buffer_load<kHeadDimV, Element, TcpSwizzle>(o_ptr);
#pragma unroll
for(int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
......@@ -234,7 +286,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
// write to global memory
if constexpr (Is_even_MN) {
inline_buffer_store_dword<vec2_Element<Element>, 1>(v_data, v_offset, gO, s_offset, /* immediate integer */s_offset_constexpr);
inline_buffer_store_dword<vec2_Element<Element>, 1>(v_data, v_offset, o_resource, s_offset, /* immediate integer */s_offset_constexpr);
} else {
if(m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(vec2_Element<Element>*)(o_ptr + v_offset + s_offset + s_offset_constexpr) = v_data;
......@@ -252,4 +304,4 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
}
}
\ No newline at end of file
}
#include "numeric_types.h"
#include "intrinsic.h"
/**
 * Store epilogue for the MLS gfx92a forward kernels.
 *
 * Walks every 32x32 accumulator tile of `acc_o`, converts pairs of accumulator
 * values to `Element`, and writes each pair with inline_buffer_store_dword
 * through a buffer resource built from `o_ptr`. Each pair combines the same
 * lane index `vec_index` of sub-tiles `min_tile_m + 0 * 2` and
 * `min_tile_m + 1 * 2`. After all stores are issued the function waits via
 * flash::wait_buffer_data_arrived<true>(0).
 *
 * This path is kept private to the MLS gfx92a kernels so that the generic
 * forward epilogue can keep serving the legacy FA_FWD_NO_MLS path unchanged.
 *
 * @param o_ptr           base pointer of the output, handed to prepare_for_buffer_load.
 * @param acc_o           accumulators to convert and store.
 * @param m_block         block index along M; only used for the bounds check.
 * @param warp_id         warp index inside the block; selects the WARP_M row band.
 * @param lane_id         lane index; low 4 bits select the row, the remaining bits the column group.
 * @param seqlen_o_stride row stride of the output, in elements.
 * @param seqlen_q_limit  exclusive row bound applied when Is_even_MN is false.
 *
 * Is_Interleaved is accepted for signature compatibility and is not read here.
 * If Element is neither bhalf_t nor half_t no conversion branch is taken and
 * v_data is stored as declared.
 */
template<int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, bool Is_even_MN, bool Is_Interleaved, bool TcpSwizzle, typename Element, typename ElementAccum>
__forceinline__ __device__ void fwd_epilogue_store_output_mls_gfx92a(
    Element *o_ptr,
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4],
    int m_block,
    int warp_id,
    int lane_id,
    int seqlen_o_stride,
    int seqlen_q_limit) {
    // Lane decomposition: 16 rows per half tile, remaining bits pick the column group.
    const int lane_row = lane_id & 15;
    const int lane_col_group = lane_id >> 4;

    auto gO = prepare_for_buffer_load<kHeadDimV, Element, TcpSwizzle>(o_ptr);

    #pragma unroll
    for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
        #pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < (WARP_M / 32); ++warp_m_idx) {
            #pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
                // Index of the 32x32 tile inside acc_o; independent of the two inner loops.
                const int tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
                #pragma unroll 2
                for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    // Row of this lane relative to the start of the M block.
                    const int q_row = warp_id * WARP_M + warp_m_idx * 32 + lane_row + min_tile_m * 16;
                    #pragma unroll
                    for (int vec_index = 0; vec_index < 4; ++vec_index) {
                        int s_offset = k_loop * kBlockK;
                        int s_offset_constexpr = k_tile_idx * 32 + vec_index * 8;
                        int v_offset = q_row * seqlen_o_stride + lane_col_group * 2;

                        // Narrow one accumulator pair to the output element type.
                        vec2_Element<Element> v_data;
                        if constexpr (std::is_same<Element, bhalf_t>::value) {
                            v_data = DownCastPairNoPack<ElementAccum, Element>(
                                acc_o[tile_id][min_tile_m + 0 * 2].f32[vec_index],
                                acc_o[tile_id][min_tile_m + 1 * 2].f32[vec_index]
                            );
                        }
                        else if constexpr (std::is_same<Element, half_t>::value) {
                            #ifdef USE_CVT_PKRTZ_FP16_FP32
                            v_data = DownCastPair<ElementAccum, Element>(
                                acc_o[tile_id][min_tile_m + 0 * 2].f32[vec_index],
                                acc_o[tile_id][min_tile_m + 1 * 2].f32[vec_index]
                            );
                            #else
                            v_data[0] = DownCast<ElementAccum, Element>(acc_o[tile_id][min_tile_m + 0 * 2].f32[vec_index]);
                            v_data[1] = DownCast<ElementAccum, Element>(acc_o[tile_id][min_tile_m + 1 * 2].f32[vec_index]);
                            #endif
                        }

                        // Ragged tiles: skip rows at or beyond the sequence limit.
                        if constexpr (!Is_even_MN) {
                            if (m_block * kBlockM + q_row >= seqlen_q_limit) {
                                continue;
                            }
                        }
                        inline_buffer_store_dword<vec2_Element<Element>, 1>(v_data, v_offset, gO, s_offset, s_offset_constexpr);
                    }
                }
            }
        }
    }
    flash::wait_buffer_data_arrived<true>(0);
}
/**
 * gfx92a store epilogue entry point.
 *
 * Forwards all template and runtime arguments, unchanged and in the same
 * order, to fwd_epilogue_store_output_mls_gfx92a.
 */
template<int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, bool Is_even_MN, bool Is_Interleaved, bool TcpSwizzle, typename Element, typename ElementAccum>
__forceinline__ __device__ void fwd_epilogue_store_output_gfx92a(
    Element *o_ptr,
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4],
    int m_block,
    int warp_id,
    int lane_id,
    int seqlen_o_stride,
    int seqlen_q_limit) {
    fwd_epilogue_store_output_mls_gfx92a<
        kHeadDimV, kBlockM, kBlockK, WARP_M,
        Is_even_MN, Is_Interleaved, TcpSwizzle,
        Element, ElementAccum>(
        o_ptr,
        acc_o,
        m_block,
        warp_id,
        lane_id,
        seqlen_o_stride,
        seqlen_q_limit);
}
#include "fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h"
#include "static_switch.h"
/**
 * P*V GEMM for the MLS/DS gfx92a forward path, with an optional K prefetch.
 *
 * The function consumes `kBlockK / 32` slices of `p_reg`. For every slice but
 * the last it issues a matrix load of the next 32-row V slice into the LDS
 * stage selected by `lds_stage_id`, then flips the stage and multiplies the
 * current slice against the V tiles read from the other stage. The last slice
 * is handled after the loop; when PREFETCH_K is set, K for the next iteration
 * is prefetched right before it via prefetch_k_to_lds_mls_ds.
 *
 * Within one slice the V tiles along kHeadDimV are read one step ahead of the
 * MMACs (register stages selected by `stage_id`), and each tile is consumed
 * in two halves (`min_tile_k` 0 and 1) separated by LDS waits.
 *
 * The order of loads, waits, priority changes and MMACs is hand-scheduled;
 * do not reorder statements without re-validating on hardware.
 *
 * @param v_ptr             V descriptor words; words 0/1 are read as a 64-bit base address.
 * @param k_ptr             K descriptor, only forwarded to the K prefetch.
 * @param v_lds             LDS pointer passed to the V matrix load.
 * @param k_lds             LDS pointer passed to the K prefetch.
 * @param p_reg             P operand tiles; permuted in place on gfx92a before use.
 * @param pv_reg            output accumulators, updated in place.
 * @param warp_id           warp index; selects this warp's 32-column slab of V.
 * @param seqlen_k_stride   K row stride, forwarded to the K prefetch.
 * @param seqlen_v_stride   V row stride in elements.
 * @param max_seq_kv_offset number of valid KV rows in this block; only used
 *                          when Is_even_MN is false to derive the row filter.
 */
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_gfx92a(
    vec4_uint v_ptr,
    vec4_uint k_ptr,
    Element* v_lds,
    Element* k_lds,
    union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockK / 32)][4],
    vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 32) * (kBlockN / 32)][4],
    int warp_id,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int max_seq_kv_offset=0) {
    constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int WARP_K = 32;
    constexpr int READ_ONCE_COUNT = 32 * 32;
    constexpr int V_LDS_LOAD_NUM = (kHeadDimV * WARP_K) / READ_ONCE_COUNT;
    constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
    static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
    // The message is a separate argument so that it is actually reported on
    // failure (the former `cond and "msg"` form folded it into the condition).
    static_assert (WARP_K == 32, "Error: To simplify, only WARP_K = 32 is supported!");
    static_assert (WARP_M == 32, "Error: To simplify, only WARP_M = 32 is supported!");
    static_assert (WARP_N == 32, "Error: To simplify, only WARP_N = 32 is supported!");
    // Prepare V regs
    union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
    // Prepare V lds offset
    int v_lds_base = 0; // reinterpret_cast<size_t>(v_lds); // ===> performance drop?
    // Prepare MLS buffer resource sregs
    vec4_uint v_srsrc;
    v_srsrc[0] = v_ptr[0];
    v_srsrc[1] = v_ptr[1];
    v_srsrc[2] = seqlen_v_stride; // stride
    v_srsrc[3] = 0;
    int lds_stage_id = 1;
    // Main loop over the 32-row slices of this block along seqlen_kv
    // (all slices except the last one, which is handled after the loop).
    for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
        // Do k-dim interleave for next mmac
        #if defined(__gfx92a__)
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 0 + 0].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 0 + 1].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 1 + 0].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 1 + 1].f16x4);
        #endif
        // MLS dispatch
        if constexpr (Is_even_MN) {
            *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES);
            v_srsrc[3] = 0x20000;
        } else {
            // When the whole slice lies beyond max_seq_kv_offset, slice 0 is
            // loaded instead and the row filter is derived from that slice.
            int nm_filter_max = n_loop * WARP_K + 32 - max_seq_kv_offset;
            int real_mls_loop = nm_filter_max >= 32 ? 0 : n_loop;
            int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset);
            v_srsrc[3] = (nm_filter << 8) + 0x20000;
            *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES);
        }
        int lds_write_offset = (lds_stage_id * WARP_K * kHeadDimV + warp_id * 32 * 32) * ELEMENT_BYTES;
        flash::wait_all_warp_arrived();
        inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_write_offset, 0);
        // Wait buffer
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
        // DS dispatch
        int lds_load_offset = (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
        // Wait ds_mpermute
        #if defined(__gfx92a__)
        flash::wait_lds_data_arrived<false>(2);
        #endif
        stage_id ^= 1;
        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
            // DS dispatch
            int lds_load_offset = (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
            // Wait DS
            flash::wait_lds_data_arrived<false>(3);
            // MMAC
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
                flash::raise_priority(1);
                #pragma unroll
                for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                flash::lower_priority();
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
                flash::raise_priority(1);
                #pragma unroll
                for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                flash::lower_priority();
            }
        }
        stage_id ^= 1;
        // Wait DS
        flash::wait_lds_data_arrived<false>(1);
        // last mmac
        {
            constexpr int min_tile_k = 0;
            flash::raise_priority(1);
            #pragma unroll
            for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            flash::lower_priority();
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
            flash::raise_priority(1);
            #pragma unroll
            for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            flash::lower_priority();
        }
    }
    // Prefetch K
    if constexpr (PREFETCH_K) {
        prefetch_k_to_lds_mls_ds<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset);
    }
    {
        // Final slice: index of the last p_reg row plus one. Derived from
        // kBlockK so that it always matches the extent of p_reg and the trip
        // count of the loop above (this evaluates to 4 for kBlockK == 128).
        constexpr int n_loop = kBlockK / WARP_K;
        // Do k-dim interleave for next mmac
        #if defined(__gfx92a__)
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 0 + 0].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 0 + 1].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 1 + 0].f16x4);
        ds_mpermute_kdim_for_mmac(p_reg[n_loop - 1][2 * 1 + 1].f16x4);
        #endif
        lds_stage_id ^= 1;
        int stage_id = 0;
        // Wait buffer
        if constexpr (PREFETCH_K) {
            flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
        } else {
            flash::wait_buffer_data_arrived<true>(0);
        }
        // DS dispatch
        int lds_load_offset = (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
        // Wait ds_mpermute
        #if defined(__gfx92a__)
        flash::wait_lds_data_arrived<false>(2);
        #endif
        stage_id ^= 1;
        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
            // DS dispatch
            int lds_load_offset = (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
            // Wait DS
            flash::wait_lds_data_arrived<false>(3);
            // MMAC
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
                flash::raise_priority(1);
                #pragma unroll
                for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                flash::lower_priority();
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
                flash::raise_priority(1);
                #pragma unroll
                for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                flash::lower_priority();
            }
        }
        stage_id ^= 1;
        flash::wait_lds_data_arrived<false>(1);
        // last mmac
        {
            constexpr int min_tile_k = 0;
            flash::raise_priority(1);
            #pragma unroll
            for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            flash::lower_priority();
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
            flash::raise_priority(1);
            #pragma unroll
            for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][2 * min_tile_k + min_tile_m].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            flash::lower_priority();
        }
    }
}
#pragma once
#include "fwd/gfx938/pv_gemm_utils_mls_ds.h"
#include "fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h"
/**
 * Q*K^T GEMM for the MLS/DS gfx92a "2TG" forward path, followed by an
 * optional V prefetch.
 *
 * `s_reg` is zeroed first. Then, for each 32-column slice `n_loop` of the
 * block, the function waits for outstanding buffer loads, reads K tiles from
 * LDS (offset `k_lds_base` plus the slice/tile offset) one step ahead of the
 * MMACs, and accumulates q_reg x k_reg into `s_reg[n_loop]`. Each K tile is
 * consumed in two halves (`min_tile_n` 0 and 1) separated by LDS waits.
 * When STAGES == 2 and the target is gfx938/gfx946, V is prefetched into LDS
 * at the end via prefetch_v_to_lds_mls_ds; on other targets nothing follows.
 *
 * The order of loads, waits, priority changes and MMACs is hand-scheduled;
 * do not reorder statements without re-validating on hardware.
 *
 * @param k_ptr            K descriptor; not read by the visible statements.
 * @param v_ptr            V descriptor, forwarded to the V prefetch.
 * @param k_lds            LDS pointer of K; its value is used as the base read offset.
 * @param v_lds            LDS pointer of V, forwarded to the V prefetch.
 * @param q_reg            Q operand tiles.
 * @param s_reg            score accumulators; zeroed, then accumulated in place.
 * @param warp_id          warp index, forwarded to the V prefetch.
 * @param seqlen_k_stride  K row stride; only written into the local k_srsrc.
 * @param seqlen_v_stride  V row stride, forwarded to the V prefetch.
 * @param max_seq_k_offset forwarded to the V prefetch.
 *
 * NOTE(review): WARP_NUM, k_lds_load_num, K_LOAD_REQUESTS and k_srsrc are not
 * read by any visible statement. They are kept because the DS_READ_* macro
 * bodies are not visible here and could reference locals by name -- confirm
 * before removing them.
 */
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_gfx92a_2TG(
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* k_lds,
    Element* v_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / 32)][4],
    int warp_id,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int max_seq_k_offset=0) {
    // Simplify. The message is a separate argument so that it is actually
    // reported on failure (the former `cond and "msg"` form folded it into
    // the condition).
    static_assert (kBlockK == 32, "To simplify, only kBlockK = 32 is supported!");
    static_assert (WARP_M == 32, "To simplify, only WARP_M = 32 is supported!");
    static_assert (WARP_N == 32, "To simplify, only WARP_N = 32 is supported!");
    constexpr int WARP_NUM = kBlockM / WARP_M;
    constexpr int k_lds_load_num = WARP_N * kHeadDim / (32 * 32);
    constexpr int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    // Prepare K registers
    union_vec4_f16x2<Element> k_reg[STAGES * (WARP_N * kBlockK) / (32 * 32) * 2];
    // Compute the starting K LDS offset (pointer value narrowed to int)
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    // here, v_mov_b64 can be applied
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128) {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    } else {
        for (int i = 0; i < (WARP_M / 32) * (kBlockN / 32); ++i) { // for kBlockN = 64, only wave 0 get the right QK gemm results
            for (int j = 0; j < 4; ++j) {
                s_reg[i][j].u64[0] = 0;
                s_reg[i][j].u64[1] = 0;
            }
        }
    }
    flash::lower_priority();
    // MLS
    vec4_uint k_srsrc;
    k_srsrc[2] = seqlen_k_stride; // stride
    k_srsrc[3] = 0;
    #pragma unroll
    for(int n_loop = 0; n_loop < (kBlockN / WARP_N); ++n_loop) {
        // Wait global data
        flash::wait_buffer_data_arrived<true>(kBlockN / WARP_N - n_loop - 1);
        // DS
        int stage_id = 0;
        {
            constexpr int k_loop = 0;
            int lds_load_offset = k_lds_base + (n_loop * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        stage_id ^= 1;
        #pragma unroll
        for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
            // DS
            int lds_load_offset = k_lds_base + (n_loop * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
            flash::wait_lds_data_arrived<false>(3);
            flash::raise_priority();
            // MMAC
            stage_id ^= 1;
            {
                int min_tile_n = 0;
                #pragma unroll
                for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
                        int q_tile_id = k_loop_idx * 2 + min_tile_m;
                        int k_tile_id = stage_id * 2 + min_tile_n;
                        s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            q_reg[q_tile_id].f16x4[min_tile_k],
                            k_reg[k_tile_id].f16x4[min_tile_k],
                            s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                int min_tile_n = 1;
                #pragma unroll
                for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    #pragma unroll
                    for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                        int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
                        int q_tile_id = k_loop_idx * 2 + min_tile_m;
                        int k_tile_id = stage_id * 2 + min_tile_n;
                        s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                            q_reg[q_tile_id].f16x4[min_tile_k],
                            k_reg[k_tile_id].f16x4[min_tile_k],
                            s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
            }
            flash::lower_priority();
        }
        stage_id ^= 1;
        // Wait DS
        flash::wait_lds_data_arrived<false>(1);
        flash::raise_priority();
        // last mmac
        {
            int min_tile_n = 0;
            #pragma unroll
            for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int k_loop_idx = kHeadDim / kBlockK - 1;
                    int q_tile_id = k_loop_idx * 2 + min_tile_m;
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[q_tile_id].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            int min_tile_n = 1;
            #pragma unroll
            for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                #pragma unroll
                for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                    int k_loop_idx = kHeadDim / kBlockK - 1;
                    int q_tile_id = k_loop_idx * 2 + min_tile_m;
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[q_tile_id].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[n_loop][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
        flash::lower_priority();
    }
    if constexpr (STAGES == 2) {
        // V is prefetched here only on gfx938/gfx946; other targets do nothing.
        #if defined(__gfx938__) || defined(__gfx946__)
        prefetch_v_to_lds_mls_ds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 2, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset);
        #endif
    }
} // qk_gemm
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_gfx92a(
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / 32)][4],
int warp_id,
int seqlen_k_stride,
int seqlen_v_stride,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 32 and "To simplify, only WARP_M = 32 is supported!");
static_assert (WARP_N == 32 and "To simplify, only WARP_N = 32 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int k_lds_load_num = WARP_N * kHeadDim / (32 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
// Prepare regs for k
union_vec4_f16x2<Element> k_reg[STAGES * (WARP_N * kBlockK) / (32 * 32) * 2];
// Zero-initialize s_reg
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 32) * (kBlockN / 32); ++i) { // for kBlockN = 64, only wave 0 get the right QK gemm results
for (int j = 0; j < 4; ++j) {
s_reg[i][j].u64[0] = 0;
s_reg[i][j].u64[1] = 0;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Prepare MLS buffer resource sregs
vec4_uint k_srsrc;
k_srsrc[2] = seqlen_k_stride; // stride
k_srsrc[3] = 0;
int n_stage_id = 1;
#pragma unroll
for(int n_loop = 1; n_loop < (kBlockN / WARP_N); ++n_loop) {
// MLS dispatch
const bool has_tail = max_seq_k_offset % kBlockN != 0;
const int nm_filter_max = n_loop * WARP_N + 32 - max_seq_k_offset;
const int k_load_loop = has_tail && nm_filter_max >= 32 ? 0 : n_loop;
const int nm_filter = inline_min_max<0, 31>(k_load_loop * WARP_N + 32 - max_seq_k_offset);
const int __nm_filter = __builtin_amdgcn_readfirstlane(nm_filter);
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (k_load_loop * WARP_N * seqlen_k_stride + warp_id * 32) * ELEMENT_BYTES);
k_srsrc[3] = has_tail ? __nm_filter << 8 : 0; // set only once
int lds_write_offset = (n_stage_id * WARP_N * kHeadDim + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // sync lds usage when ping-pong
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_write_offset, 0);
// Wait MLS
n_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
// DS dispatch
{
constexpr int k_loop = 0;
int lds_load_offset = (n_stage_id * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
stage_id ^= 1;
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// DS dispatch
int lds_load_offset = (n_stage_id * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
// Wait DS
flash::wait_lds_data_arrived<false>(3);
// MMAC
asm volatile("s_setprio 2");
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
asm volatile("s_setprio 0");
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
asm volatile("s_setprio 2");
// MMAC
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
asm volatile("s_setprio 0");
}
{
// Wait MLS
constexpr int n_loop = 4;
n_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(0);
// DS dispatch
{
int k_loop = 0;
int lds_load_offset = (n_stage_id * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
stage_id ^= 1;
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// DS dispatch
int lds_load_offset = (n_stage_id * WARP_N * kHeadDim + k_loop * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
// Wait DS
flash::wait_lds_data_arrived<false>(3);
// MMAC
asm volatile("s_setprio 2");
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
asm volatile("s_setprio 0");
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// MMAC
asm volatile("s_setprio 2"); // flash::raise_priority 性能下降严重, 157.8 -> 148.2 tflops, strange, 需要看汇编
{ // 对比汇编, 差异就在于单独的 s_setprio 会被胡乱调度到 mmac 中间, 但这样跑出来却性能更高; 强行加 scheduled barrier 跑出来性能更低;
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * 2 + min_tile_m;
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_loop - 1][min_tile_n * 2 + min_tile_m].f32);
}
}
}
asm volatile("s_setprio 0"); // flash::lower_priority 性能下降 0.6 tflops 左右
}
if constexpr (STAGES == 2) {
prefetch_v_to_lds_mls_ds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 2, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "fwd/gfx938/qk_gemm_utils_mls_ds.h"
// Loads this warp's Q rows into q_reg, one kBlockK-wide slice of the head
// dimension at a time, staging each slice through q_lds. Two staging slots are
// alternated (stage_id toggles between 0 and 1): the load of slice k is issued
// before slice k-1 is read back from the other slot into registers.
//
// Template parameters:
//   kHeadDim, kBlockK - the number of slices processed is kHeadDim / kBlockK.
//   kBlockM, WARP_M   - used to derive WARP_NUM and Q_LOAD_REQUESTS below.
//   Element           - element type; sizeof(Element) scales every byte offset.
//   Is_even_MN        - not referenced by this body.
// Parameters:
//   q_ptr            - its first 64 bits are read as the base address of Q.
//   q_lds            - staging buffer written by the matrix loads and read by the DS reads.
//   q_reg            - output; slice k is read into q_reg[2 * k] and q_reg[2 * k + 1].
//   warp_id          - selects the 32-row band starting at row warp_id * 32 and
//                      the warp's 32x32 slot inside each staging stage.
//   seqlen_q_stride  - row pitch of Q in elements; also written to descriptor word 2.
//   max_seq_q_offset - used to detect a partial block and to derive the row filter
//                      (presumably the count of valid Q rows in the block - confirm with caller).
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds_gfx92a(
vec4_uint q_ptr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
int warp_id,
int seqlen_q_stride,
int max_seq_q_offset=0) {
constexpr int WARP_NUM = kBlockM / WARP_M;
// Number of 32x32 tiles in one kBlockM x kBlockK slice, and this warp's share of them.
constexpr int Q_LDS_LOAD_NUM = kBlockM * kBlockK / (32 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
// The q_lds pointer value is truncated to int and used as a byte address for the DS reads.
int q_lds_base = reinterpret_cast<size_t>(q_lds);
// NOTE(review): presumably drains outstanding LDS traffic before q_lds is overwritten - confirm.
flash::wait_lds_data_arrived<true>(0);
// Load descriptor: words 0-1 hold the address (set per slice below),
// word 2 the row stride, word 3 the optional row filter.
vec4_uint q_srsrc;
q_srsrc[2] = seqlen_q_stride;
const int q_row = warp_id * 32;
// nm_filter >= 32 is equivalent to q_row >= max_seq_q_offset, i.e. the whole
// 32-row band of this warp starts at or beyond max_seq_q_offset.
const int nm_filter = q_row + 32 - max_seq_q_offset;
const bool has_tail = max_seq_q_offset % kBlockM != 0;
// In a partial block, a band that lies entirely beyond max_seq_q_offset is
// redirected to row 0 instead of its own rows.
const int q_load_row = has_tail && nm_filter >= 32 ? 0 : q_row;
// gfx92a has a 5-bit MLS filter field, so never encode 32.
const int q_filter = inline_min_max<0, 31>(q_load_row + 32 - max_seq_q_offset);
// The filter is placed at bit 8 of descriptor word 3; full blocks use no filter.
q_srsrc[3] = has_tail ? q_filter << 8 : 0;
int stage_id = 0;
{
// Prologue: issue the load of slice 0 into stage 0.
int k_loop = 0;
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_ptr + (k_loop * kBlockK + q_load_row * seqlen_q_stride) * ELEMENT_BYTES);
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
}
stage_id ^= 1;
#pragma unroll
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// Issue the load of slice k_loop into the current stage ...
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_ptr + (k_loop * kBlockK + q_load_row * seqlen_q_stride) * ELEMENT_BYTES);
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
// ... then switch to the other stage, which holds slice k_loop - 1.
stage_id ^= 1;
// NOTE(review): presumably waits until at most Q_LOAD_REQUESTS loads are still
// in flight, i.e. until slice k_loop - 1 has landed - confirm against the helper.
buffer_load_lds_dwordx1_wait<Q_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// Read slice k_loop - 1 from its stage into the two registers reserved for it.
int lds_load_offset = q_lds_base + (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, q_reg[(k_loop - 1) * 2].f16, q_reg[(k_loop - 1) * 2 + 1].f16, true);
flash::wait_lds_data_arrived<true>(0);
}
{
// Epilogue: the last slice has been issued but not yet read back.
stage_id ^= 1;
buffer_load_lds_dwordx1_wait<0>();
__builtin_amdgcn_sched_barrier(0);
constexpr int k_loop = kHeadDim / kBlockK - 1;
int lds_load_offset = q_lds_base + (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, q_reg[k_loop * 2].f16, q_reg[k_loop * 2 + 1].f16, true);
}
// Full wait (all counters at zero) before returning.
__builtin_amdgcn_s_waitcnt(0);
flash::wait_lds_data_arrived<true>(0);
}
// Issues one 32x32 matrix load per WARP_N-row tile of the K block, writing
// tile `tile` to k_lds at element offset tile * WARP_N * kHeadDim plus this
// warp's 32x32 slot. Each warp covers the 32-element-wide column band that
// starts at column warp_id * 32.
//
// Template parameters kBlockK, WARP_NUM and Is_even_MN are not referenced here.
// Parameters:
//   k_ptr            - its first 64 bits are read as the base address of K.
//   k_lds            - destination buffer of the matrix loads.
//   warp_id          - selects the column band and the destination slot.
//   seqlen_k_stride  - row pitch of K in elements; also written to descriptor word 2.
//   max_seq_k_offset - used to detect a partial block and to derive the row filter.
template<int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_gfx92a(
    vec4_uint k_ptr,
    Element* k_lds,
    int warp_id,
    int seqlen_k_stride,
    int max_seq_k_offset=0) {
    constexpr int kElemBytes = sizeof(Element);
    flash::wait_all_warp_arrived();
    // The block is partial when max_seq_k_offset is not a multiple of kBlockN;
    // this does not depend on the tile, so it is evaluated once.
    const bool partial_block = (max_seq_k_offset % kBlockN) != 0;
    #pragma unroll
    for (int tile = 0; tile < kBlockN / WARP_N; ++tile) {
        // rows_past_end >= 32 means the first 32 rows of this tile all lie at or
        // beyond max_seq_k_offset; such a tile is redirected to tile 0.
        const int rows_past_end = tile * WARP_N + 32 - max_seq_k_offset;
        const bool redirect_to_first = partial_block && (rows_past_end >= 32);
        const int src_tile = redirect_to_first ? 0 : tile;
        // Row filter clamped to [0, 31] so it fits the descriptor field.
        const int row_filter = inline_min_max<0, 31>(src_tile * WARP_N + 32 - max_seq_k_offset);

        // Descriptor: words 0-1 address, word 2 stride, word 3 filter (bit 8 upwards).
        vec4_uint desc;
        *(uint64_t*)&desc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (src_tile * WARP_N * seqlen_k_stride + warp_id * 32) * kElemBytes);
        desc[2] = seqlen_k_stride;
        desc[3] = partial_block ? row_filter << 8 : 0;

        const int lds_byte_offset = (tile * WARP_N * kHeadDim + warp_id * 32 * 32) * kElemBytes;
        inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, desc, lds_byte_offset, 0);
    }
    __builtin_amdgcn_sched_barrier(0);
}
#pragma once
#include "philox.cuh"
#include "../utils.h"
using namespace flash;
// Sets every score whose column index is >= max_seqlen_k to -INFINITY.
//
// Column index of an element, as computed below:
//   col_idx_offset_ + (lane >> 4) + tile_n * 32 + half_n * 16 + v * 4
// where lane = threadIdx.x & 63, tile_n indexes 32-wide column tiles, half_n the
// two 16-wide halves of a tile and v the 4 components of the f32 vector.
// The mask is applied to all row tiles and both row halves of a masked column.
//
// tensor is indexed as [tile_m + tile_n * (WARP_M / 32)][half_n * 2 + half_m].
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask_gfx92a(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k,
                                         const int col_idx_offset_ = 0) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
        #pragma unroll
        for (int half_n = 0; half_n < 2; ++half_n) {
            #pragma unroll
            for (int v = 0; v < 4; ++v) {
                const int col = lane_col + tile_n * 32 + half_n * 16 + v * 4;
                if (col < max_seqlen_k) {
                    continue;  // column is in range, nothing to mask
                }
                #pragma unroll
                for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
                    #pragma unroll
                    for (int half_m = 0; half_m < 2; ++half_m) {
                        tensor[tile_m + tile_n * kTilesM][half_n * 2 + half_m].f32[v] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Causal mask: sets a score to -INFINITY when its column index is greater than
//   min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q).
//
// Index mapping, as computed below (lane = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane & 15) + tile_m * 32 + half_m * 16
//   col = col_idx_offset_ + (lane >> 4) + tile_n * 32 + half_n * 16 + v * 4
// tensor is indexed as [tile_m + tile_n * (WARP_M / 32)][half_n * 2 + half_m].
//
// NOTE(review): because the limit is clamped to max_seqlen_k and compared with
// a strict '>', column index == max_seqlen_k stays unmasked for rows with
// row >= max_seqlen_q. Confirm those rows are never consumed by the caller.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask_causal_gfx92a(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                                const int max_seqlen_k, const int row_idx_offset_,
                                                const int max_seqlen_q) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15);
    const int lane_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int half_m = 0; half_m < 2; ++half_m) {
            const int row = lane_row + tile_m * 32 + half_m * 16;
            // Last column this row is allowed to see.
            const int last_visible_col = std::min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int half_n = 0; half_n < 2; ++half_n) {
                    DataType& acc = tensor[tile_m + tile_n * kTilesM][half_n * 2 + half_m];
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        const int col = lane_col + tile_n * 32 + half_n * 16 + v * 4;
                        if (col > last_visible_col) {
                            acc.f32[v] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Sliding-window (local) mask: sets a score to -INFINITY when its column lies
// outside the window of its row or beyond the last valid key column.
//
// With shift = max_seqlen_k - max_seqlen_q, a column `col` of row `row` is kept iff
//   col <= min(max_seqlen_k - 1, row + shift + window_size_right)
// and, when HasWSLeft is true,
//   col >= max(0, row + 1 + shift - window_size_left) - 1.
//
// Index mapping, as computed below (lane = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane & 15) + tile_m * 32 + half_m * 16
//   col = col_idx_offset_ + (lane >> 4) + tile_n * 32 + half_n * 16 + v * 4
// tensor is indexed as [tile_m + tile_n * (WARP_M / 32)][half_n * 2 + half_m].
//
// Template parameters:
//   HasWSLeft - when false the left window bound is not applied.
// Parameters:
//   max_seqlen_k / max_seqlen_q - key / query lengths; valid columns are
//                                 [0, max_seqlen_k).
//   window_size_left / window_size_right - window extents around the diagonal.
template <bool HasWSLeft=true, typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask_local_gfx92a(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                               const int max_seqlen_k, const int row_idx_offset_,
                                               const int max_seqlen_q,
                                               const int window_size_left, const int window_size_right) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane_id = threadIdx.x & 63;
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4);
    #pragma unroll
    for (int mi = 0; mi < kTilesM; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int col_idx_limit_left = std::max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
            // Inclusive right limit (compared with '>' below). The clamp must be
            // max_seqlen_k - 1, the last valid column: clamping to max_seqlen_k
            // left column index == max_seqlen_k unmasked for valid rows whenever
            // row_idx + window_size_right >= max_seqlen_q.
            const int col_idx_limit_right = std::min(max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int ni = 0; ni < kTilesN; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    DataType& acc = tensor[mi + ni * kTilesM][min_tile_n * 2 + min_tile_m];
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx * 4;
                        const bool right_of_window = col_idx > col_idx_limit_right;
                        // col_idx_limit_left carries a +1, hence the -1 in the comparison.
                        const bool left_of_window = HasWSLeft && col_idx < (col_idx_limit_left - 1);
                        if (right_of_window || left_of_window) {
                            acc.f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Adds the bias g_alibi * (col - row) to every score.
//
// Index mapping, as computed below (lane = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane & 15) + tile_m * 32 + half_m * 16
//   col = col_idx_offset_ + (lane >> 4) + tile_n * 32 + half_n * 16 + v * 4
// tensor is indexed as [tile_m + tile_n * (WARP_M / 32)][half_n * 2 + half_m].
//
// max_seqlen_k and max_seqlen_q are accepted but not used by this body.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_alibi_gfx92a(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                          const int max_seqlen_k, const int row_idx_offset_,
                                          const int max_seqlen_q, float g_alibi) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15);
    const int lane_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int half_m = 0; half_m < 2; ++half_m) {
            const int row = lane_row + tile_m * 32 + half_m * 16;
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int half_n = 0; half_n < 2; ++half_n) {
                    DataType& acc = tensor[tile_m + tile_n * kTilesM][half_n * 2 + half_m];
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        const int col = lane_col + tile_n * 32 + half_n * 16 + v * 4;
                        acc.f32[v] += g_alibi * (col - row);
                    }
                }
            }
        }
    }
}
\ No newline at end of file
#include "numeric_types.h"
#include "intrinsic.h"
// Reads the attention-sink value for head `head_idx` from `s_aux_ptr` and
// returns it as float. `s_aux_type` selects how the buffer is interpreted:
//   1      -> float
//   2      -> half_t
//   other  -> BFloat16
__forceinline__ __device__ float fp8_attention_sink_load(const void *s_aux_ptr, int s_aux_type, int head_idx) {
    switch (s_aux_type) {
        case 1:
            return reinterpret_cast<const float *>(s_aux_ptr)[head_idx];
        case 2:
            return UpCast<half_t, float>(reinterpret_cast<const half_t *>(s_aux_ptr)[head_idx]);
        default:
            return UpCast<BFloat16, float>(reinterpret_cast<const BFloat16 *>(s_aux_ptr)[head_idx]);
    }
}
// Folds an attention-sink logit into the running softmax state of every row group.
//
// For each m:
//   old_max  = scores_max[m] * softmax_scale
//   new_max  = max(old_max, sink_value)
//   rescale  = exp(old_max - new_max)
//   scores_sum[m] = scores_sum[m] * rescale + exp(sink_value - new_max)
//   scores_max[m] = new_max / softmax_scale     (stored back in unscaled form)
//   acc_o[..][m][..] *= rescale
template<int kHeadDim, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_attention_sink_apply(
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum softmax_scale,
    float sink_value
) {
    const ElementAccum sink = ElementAccum(sink_value);
    #pragma unroll
    for (int m = 0; m < WARP_M / 16; ++m) {
        const ElementAccum scaled_max_before = scores_max[m] * softmax_scale;
        const ElementAccum scaled_max_after = max(scaled_max_before, sink);
        const ElementAccum rescale = __expf(scaled_max_before - scaled_max_after);

        scores_sum[m] = scores_sum[m] * rescale + __expf(sink - scaled_max_after);
        scores_max[m] = scaled_max_after / softmax_scale;

        // Same factor in both halves so one packed multiply scales two floats.
        __float2 rescale_x2;
        rescale_x2[0] = rescale;
        rescale_x2[1] = rescale;
        #pragma unroll
        for (int d = 0; d < kHeadDim / 32; ++d) {
            #pragma unroll
            for (int n = 0; n < WARP_N / 16; ++n) {
                auto& out = acc_o[d][m][n];
                out.u64[0] = __builtin_hcu_pk_mul_f32(out.u64[0], rescale_x2);
                out.u64[1] = __builtin_hcu_pk_mul_f32(out.u64[1], rescale_x2);
            }
        }
    }
}
// Final normalisation of the output accumulator.
//
// For each m:
//   - if StoreLSE: lse[m] = scores_max[m] * softmax_scale + log(scores_sum[m]),
//     or INFINITY when scores_sum[m] is 0 or NaN;
//   - acc_o[..][m][..] *= v_descale / scores_sum[m]; unless AssumeValidRows is
//     set, the factor is 0 when scores_sum[m] is 0 or NaN.
template<bool AssumeValidRows, int kHeadDim, int WARP_M, int WARP_N, bool StoreLSE, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_rescale_acc_o(
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum lse[WARP_M / 16],
    ElementAccum softmax_scale,
    ElementAccum v_descale
) {
    #pragma unroll
    for (int m = 0; m < WARP_M / 16; ++m) {
        const ElementAccum row_sum = scores_sum[m];
        // `x != x` is true only for NaN.
        const bool degenerate_row = (row_sum == 0.f) || (row_sum != row_sum);

        if constexpr (StoreLSE) {
            if (degenerate_row) {
                lse[m] = INFINITY;
            } else {
                lse[m] = __llvm_fma_f32(scores_max[m], softmax_scale, __logf(row_sum));
            }
        }

        ElementAccum out_scale;
        if constexpr (AssumeValidRows) {
            out_scale = v_descale / row_sum;
        } else if (degenerate_row) {
            out_scale = 0.f;
        } else {
            out_scale = v_descale / row_sum;
        }

        // Same factor in both halves so one packed multiply scales two floats.
        __float2 out_scale_x2;
        out_scale_x2[0] = out_scale;
        out_scale_x2[1] = out_scale;
        #pragma unroll
        for (int d = 0; d < kHeadDim / 32; ++d) {
            #pragma unroll
            for (int n = 0; n < WARP_N / 16; ++n) {
                auto& out = acc_o[d][m][n];
                out.u64[0] = __builtin_hcu_pk_mul_f32(out.u64[0], out_scale_x2);
                out.u64[1] = __builtin_hcu_pk_mul_f32(out.u64[1], out_scale_x2);
            }
        }
    }
}
// Writes the per-row log-sum-exp values to softmax_lse_ptr.
//
// Only lanes with lane_id < 16 store. For m in [0, WARP_M / 16) such a lane
// owns the row
//   wave_row_offset + ((lane_id & 15) >> 2) * 8 + m * 4 + (lane_id & 3)
// and stores lse[m] at index row_offset_lse + row, provided row < actual_seqlen_q.
//
// Parameters:
//   row_offset_lse  - base index of this (batch, head) in softmax_lse_ptr.
//   wave_row_offset - first row handled by this wave.
//   scores_max, scores_sum and the Is_even_MN template flag are not used.
template<bool Is_even_MN, int WARP_M, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_store_lse(
    ElementAccum* softmax_lse_ptr,
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum lse[WARP_M / 16],
    int row_offset_lse, /*(bidb * h + bidh) * actual_seqlen_q*/
    int actual_seqlen_q,
    int wave_row_offset, /*m_block * kBlockM + warp_id * WARP_M*/
    int lane_id
) {
    if (lane_id >= 16) {
        return;
    }
    // Part of the row index that depends only on the lane.
    const int lane_row = ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    #pragma unroll
    for (int m = 0; m < WARP_M / 16; ++m) {
        const int row = wave_row_offset + lane_row + m * 4;
        if (row < actual_seqlen_q) {
            softmax_lse_ptr[row_offset_lse + row] = lse[m];
        }
    }
}
// Converts the float accumulator to Element and stores it to acc_o_ptr.
//
// For every (m_idx, pv_loop, mmac_id) a lane converts the four floats of
// acc_o[pv_loop][m_idx][mmac_id] with two DownCastPair calls and stores them as
// one union_vec2_f16x2<Element> at
//   row = warp_id * WARP_M + ((lane_id & 15) >> 2) * 8 + m_idx * 4 + (lane_id & 3)
//   col = pv_loop * 32 + mmac_id * 16 + (lane_id >> 4) * 4
//   offset = row * o_row_stride + col            (in elements from acc_o_ptr)
// When Is_even_MN is false the store is skipped for rows with
//   m_block * kBlockM + row >= actual_seqlen_q.
//
// NOTE(review): mmac_id runs over {0, 1} while the last dimension of acc_o is
// WARP_N / 16, so this appears to assume WARP_N == 32 - confirm.
template<bool Is_even_MN, int kBlockM, int kHeadDim, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_epilogue_store_output(
    Element* acc_o_ptr,
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
    int m_block,
    int warp_id,
    int lane_id,
    int o_row_stride,
    int actual_seqlen_q
) {
    using PackedOut = union_vec2_f16x2<Element>;
    // Lane-dependent parts of the row / column index.
    const int lane_row = ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int lane_col = (lane_id >> 4) * 4;
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        const int row_idx = warp_id * WARP_M + lane_row + m_idx * 4;
        // Is_even_MN is a compile-time constant, so the bounds test folds away when it is set.
        const bool row_in_range = Is_even_MN || (m_block * kBlockM + row_idx < actual_seqlen_q);
        #pragma unroll
        for (int pv_loop = 0; pv_loop < kHeadDim / 32; ++pv_loop) {
            #pragma unroll
            for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
                const int col_idx = pv_loop * 32 + mmac_id * 16 + lane_col;
                const int offset = row_idx * o_row_stride + col_idx;
                PackedOut packed;
                #pragma unroll
                for (int pair = 0; pair < 2; ++pair) {
                    packed.f16x2[pair] = DownCastPair<ElementAccum, Element>(acc_o[pv_loop][m_idx][mmac_id].f32x2[pair]);
                }
                if (row_in_range) {
                    *(PackedOut*)(acc_o_ptr + offset) = packed;
                }
            }
        }
    }
}
#include "fp8_qk_gemm_utils_mls_ds.h"
#include "static_switch.h"
// PV GEMM step with an optional prefetch of the next K block.
//
// Accumulates acc_o += P * V using mmac_4interleave_b8 over all
// (k_loop, pv_loop, m_idx, mmac_id) combinations.
//
// Template parameters:
//   PrefetchK  - when true, k_ptr is advanced by kBlockN rows (kBlockN * k_row_stride
//                elements) and fp8_prefetch_k_to_lds is issued before the MMAC
//                stream; when false, no K prefetch is done and k_ptr, k_lds,
//                warp_id, k_row_stride and max_seq_kv_offset are not used.
//   Is_even_MN, kHeadDimQK, WARP_N, Element - forwarded to fp8_prefetch_k_to_lds.
// Parameters:
//   acc_o  - output accumulator, updated in place.
//   p_reg  - P operand; p_reg[m_idx].i8x8[k_loop] feeds each MMAC.
//   v_regs - V operand; v_regs[k_loop][pv_loop].i8x8[mmac_id] feeds each MMAC.
//   v_lds  - not referenced by this body.
//   k_ptr  - passed by reference: the caller's pointer is advanced when PrefetchK is true.
//
// NOTE(review): m_idx and mmac_id run over {0, 1} while the matching acc_o
// dimensions are WARP_M / 16 and WARP_N / 16 - this appears to assume
// WARP_M == WARP_N == 32; confirm.
template<bool PrefetchK, bool Is_even_MN, int kHeadDimQK, int kHeadDimV, int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_pv_gemm_and_prefetch_k(
vec4_Accum<ElementAccum> acc_o[kHeadDimV / 32][WARP_M / 16][WARP_N / 16],
union_vec32_fp8 p_reg[WARP_M / 16],
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32],
int8_t* v_lds,
Element*& k_ptr,
int8_t* k_lds,
int warp_id,
int k_row_stride,
int max_seq_kv_offset
) {
// Wait for the data read from LDS to return (lgkmcnt == 0); the scheduling
// barriers keep the compiler from moving instructions across the wait.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
if constexpr (PrefetchK) {
// Step to the next K block and start loading it before the MMAC stream below.
k_ptr += kBlockN * k_row_stride;
fp8_prefetch_k_to_lds<Is_even_MN, kHeadDimQK, WARP_N, Element>(k_ptr, k_lds, warp_id, k_row_stride, max_seq_kv_offset);
}
// mmac stream
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; k_loop += 1) {
#pragma unroll
for (int pv_loop = 0; pv_loop < kHeadDimV / 32; ++pv_loop) {
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
acc_o[pv_loop][m_idx][mmac_id].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(p_reg[m_idx].i8x8[k_loop], v_regs[k_loop][pv_loop].i8x8[mmac_id], acc_o[pv_loop][m_idx][mmac_id].f32);
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
}
// Always-prefetch variant for paged KV: issues the prefetch of the next K block
// and then runs the PV MMAC stream.
//
// Unlike fp8_pv_gemm_and_prefetch_k, the address of the next K block is supplied
// by the caller (k_ptr_next) instead of being derived by advancing a pointer.
//
// Parameters:
//   acc_o  - output accumulator, updated in place.
//   p_reg  - P operand; p_reg[m_idx].i8x8[k_loop] feeds each MMAC.
//   v_regs - V operand; v_regs[k_loop][pv_loop].i8x8[mmac_id] feeds each MMAC.
//   v_lds  - not referenced by this body.
//   k_ptr_next, k_lds, warp_id, k_row_stride, max_seq_kv_offset_next
//          - forwarded unchanged to fp8_prefetch_k_to_lds.
//
// NOTE(review): m_idx and mmac_id run over {0, 1} while the matching acc_o
// dimensions are WARP_M / 16 and WARP_N / 16 - this appears to assume
// WARP_M == WARP_N == 32; confirm.
template<bool Is_even_MN, int kHeadDimQK, int kHeadDimV, int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_pv_gemm_and_prefetch_k_paged(
vec4_Accum<ElementAccum> acc_o[kHeadDimV / 32][WARP_M / 16][WARP_N / 16],
union_vec32_fp8 p_reg[WARP_M / 16],
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDimV / 32],
int8_t* v_lds,
Element* k_ptr_next,
int8_t* k_lds,
int warp_id,
int k_row_stride,
int max_seq_kv_offset_next
) {
// Wait for the data read from LDS to return (lgkmcnt == 0); the scheduling
// barriers keep the compiler from moving instructions across the wait.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// Prefetch the next K block into LDS (intended to overlap with the MMAC stream below).
__builtin_amdgcn_sched_barrier(0);
fp8_prefetch_k_to_lds<Is_even_MN, kHeadDimQK, WARP_N, Element>(k_ptr_next, k_lds, warp_id, k_row_stride, max_seq_kv_offset_next);
__builtin_amdgcn_sched_barrier(0);
// mmac stream
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; k_loop += 1) {
#pragma unroll
for (int pv_loop = 0; pv_loop < kHeadDimV / 32; ++pv_loop) {
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
acc_o[pv_loop][m_idx][mmac_id].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(p_reg[m_idx].i8x8[k_loop], v_regs[k_loop][pv_loop].i8x8[mmac_id], acc_o[pv_loop][m_idx][mmac_id].f32);
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic_mls_ds.h"
// Prefetches this warp's 32 rows of V into v_lds as two 16-row tiles, each
// loaded with __builtin_hcu_matrix_load_128X16_b8 (plus a second load per tile
// for the upper 128 columns when kHeadDim == 256).
//
// Template parameters:
//   kHeadDim   - must be 128 or 256 (static_assert below).
//   WARP_N     - used for the per-warp destination offset in v_lds.
//   Is_even_MN, kBlockN - not referenced by this body.
// Parameters:
//   v_ptr             - base pointer of V, turned into a descriptor by prepare_for_matrix_load.
//   v_lds             - destination buffer.
//   warp_id           - selects rows [warp_id * 32, warp_id * 32 + 32) and the destination slot.
//   v_row_stride      - row pitch of V in elements; also written to descriptor word 2.
//   max_seq_kv_offset - used to derive the per-tile row filter.
//
// NOTE(review): the tail test uses a literal 128 (max_seq_kv_offset % 128)
// rather than kBlockN - presumably kBlockN is always 128 here; confirm.
// NOTE(review): tile addresses are formed by adding to descriptor word 0 only
// (32-bit arithmetic), so a carry into word 1 would be lost - confirm the base
// address layout makes that impossible.
template<bool Is_even_MN, int kBlockN, int kHeadDim, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_prefetch_v_to_lds(
Element* v_ptr,
int8_t* v_lds,
int warp_id,
int v_row_stride,
int max_seq_kv_offset
) {
static_assert(kHeadDim == 128 || kHeadDim == 256);
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint v_root = prepare_for_matrix_load<kHeadDim, Element>(v_ptr);
vec4_uint v_srsrc;
v_srsrc[0] = v_root[0];
v_srsrc[1] = v_root[1];
v_srsrc[2] = v_row_stride; // stride
v_srsrc[3] = 0x00000;
// The 4 waves prefetch the whole block directly; each wave writes to its own
// WARP_N * kHeadDim element region of v_lds.
int v_lds_write_bytes = warp_id * WARP_N * kHeadDim * sizeof(Element);
// Each pass reads 32x128 of data.
// tile1: rows [warp_id*32, warp_id*32+16)
// When the whole tile would be filtered, keep one valid row of V so that
// 0 * NaN cannot contaminate PV.
int nm_filter_warp0_tile1 = inline_min_max<0, 16>(16 - max_seq_kv_offset);
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_kv_offset);
// nm_filter == 16 means all 16 rows of this tile are filtered: fall back to
// the base address (warp 0's tile 1) with that tile's filter capped at 15.
v_srsrc[0] = (nm_filter == 16) ? v_root[0] : v_root[0] + (warp_id * 2) * 16 * v_row_stride * sizeof(Element);
nm_filter = (nm_filter == 16) ? min(nm_filter_warp0_tile1, 15) : nm_filter;
// The filter goes to bit 8 of descriptor word 3; no filter when
// max_seq_kv_offset is a multiple of 128.
v_srsrc[3] = v_srsrc[3] + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
flash::wait_all_warp_arrived();
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds+v_lds_write_bytes, 0, true, false, false, false, false);
if constexpr (kHeadDim == 256) {
// Second half of the head dimension: source advanced by 128 elements,
// destination by 4096 bytes.
v_srsrc[0] += 128 * sizeof(Element);
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds + v_lds_write_bytes + 4096, 0, true, false, false, false, false);
}
// tile2: rows [warp_id*32+16, warp_id*32+32)
int nm_filter_warp0_tile2 = inline_min_max<0, 16>(32 - max_seq_kv_offset);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_kv_offset);
// Same fallback as for tile1, using warp 0's tile 2 filter.
v_srsrc[0] = (nm_filter == 16) ? v_root[0] : v_root[0] + (warp_id * 2 + 1) * 16 * v_row_stride * sizeof(Element);
nm_filter = (nm_filter == 16) ? min(nm_filter_warp0_tile2, 15) : nm_filter;
v_srsrc[3] = ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
// tile2 is written (128 * 16 >> 1) = 1024 bytes after tile1 in v_lds.
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds+v_lds_write_bytes + (128 * 16 >> 1), 0, true, false, false, false, false);
if constexpr (kHeadDim == 256) {
v_srsrc[0] += 128 * sizeof(Element);
__builtin_hcu_matrix_load_128X16_b8(v_srsrc, v_lds + v_lds_write_bytes + 4096 + (128 * 16 >> 1), 0, true, false, false, false, false);
}
}
#pragma once
#include "fp8_pv_gemm_utils_mls_ds.h"
// #define USE_DS_READ_B128_FOR_INTERLEAVE4
template<int kBlockN, int kHeadDim, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_qk_gemm(
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64],
int8_t* k_lds
) {
static_assert(kHeadDim == 128 || kHeadDim == 192 || kHeadDim == 256);
constexpr int kLdsHeadDimStride = kHeadDim == 192 ? 256 : kHeadDim;
int tx = threadIdx.x;
int lane_id = tx & 63;
int row = (lane_id & 15) >> 1;
int col = lane_id >> 4;
int col_swizzle = (row + col) & 3;
// 等待 K 的数据都写到 lds 了, 4 wave 同步
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0)\ns_barrier\n"); // hint: pv mmac 太短, 不够隐藏这一段时延, 可以考虑把 vmcnt 拆细一点, 先等一部分数据回来计算也可以
__builtin_amdgcn_sched_barrier(0);
// __syncthreads();
if constexpr (true) {
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8 k_regs[kBlockN / WARP_N][WARP_N / 16][kHeadDim / 64];
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0, k_regs[k_loop][0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024, k_regs[k_loop][1][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[k_loop][0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[k_loop][1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kLdsHeadDimStride;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[k_loop][0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[k_loop][1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[k_loop][0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[k_loop][1][1].i32x4, true/*transpose*/)
#pragma unroll
for (int h_idx = 0; h_idx < kHeadDim / 64; ++h_idx) {
const int h_offset = (kHeadDim == 192 && h_idx == 2) ? 3 * 2048 : h_idx * 2048;
k_regs[k_loop][0][h_idx].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + h_offset + 0, 0, 3, 1, 0);
k_regs[k_loop][1][h_idx].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + h_offset + 1024, 0, 3, 1, 0);
}
#endif
}
// init s_reg
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
}
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n":: "B"((kBlockN / WARP_N - k_loop - 1) * 4));
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// ======================================================== QK mmac ======================================================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
#pragma unroll
for (int h_idx = 0; h_idx < kHeadDim / 64; ++h_idx) {
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][h_idx].i8x8[0], k_regs[k_loop][n_idx][h_idx].i8x8[0], s_reg[k_loop][m_idx][n_idx].f32);
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][h_idx].i8x8[1], k_regs[k_loop][n_idx][h_idx].i8x8[1], s_reg[k_loop][m_idx][n_idx].f32);
}
}
}
}
} else if constexpr (false and WARP_M == 32 and WARP_N == 32 and kBlockN == 128 and kHeadDim == 128) {
union_vec16_fp8 k_regs[WARP_N / 16][kHeadDim / 64];
{
constexpr int k_loop = 0;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0, k_regs[0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024, k_regs[1][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < 2; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(2)\n");
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][1][1].f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
#endif
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][1][1].f32);
}
{
constexpr int k_loop = 1;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < 2; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(2)\n");
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][1][1].f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
#endif
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][1][1].f32);
}
{
constexpr int k_loop = 2;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < 2; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(2)\n");
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][1][1].f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
#endif
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][1][1].f32);
}
{
constexpr int k_loop = 3;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < 2; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(2)\n");
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][1][1].f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][1][1].f32);
}
} else if constexpr (false) {
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8 k_regs[WARP_N / 16][kHeadDim / 64];
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 0, k_regs[0][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 1024, k_regs[1][0].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 2048, k_regs[0][1].i32x4);
inline_ds_read_b128_no_wait_bytes(reinterpret_cast<size_t>(k_lds) + k_lds_load_offset + 3072, k_regs[1][1].i32x4);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < 2; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < 2; ++n_idx) {
inline_vgpr4_init_zero(s_reg[k_loop][m_idx][n_idx]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(2)\n");
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[0][0].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[0][0].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[0], k_regs[1][0].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][0].i8x8[1], k_regs[1][0].i8x8[1], s_reg[k_loop][1][1].f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][0][0].f32);
s_reg[k_loop][0][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][0][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[0][1].i8x8[0], s_reg[k_loop][1][0].f32);
s_reg[k_loop][1][0].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[0][1].i8x8[1], s_reg[k_loop][1][0].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][0][1].f32);
s_reg[k_loop][0][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[0][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][0][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[0], k_regs[1][1].i8x8[0], s_reg[k_loop][1][1].f32);
s_reg[k_loop][1][1].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[1][1].i8x8[1], k_regs[1][1].i8x8[1], s_reg[k_loop][1][1].f32);
}
} else { // 寄存器占用更少的写法
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
// init s_reg
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
s_reg[k_loop][m_idx][n_idx].f32[0] = 0;
s_reg[k_loop][m_idx][n_idx].f32[1] = 0;
s_reg[k_loop][m_idx][n_idx].f32[2] = 0;
s_reg[k_loop][m_idx][n_idx].f32[3] = 0;
}
}
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8 k_regs[WARP_N / 16][kHeadDim / 64];
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int k_lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + k_loop * WARP_N * kHeadDim;
k_regs[0][0].i32x4 = *(vec4_int32*)(k_lds + k_lds_load_offset + 0);
k_regs[1][0].i32x4 = *(vec4_int32*)(k_lds + k_lds_load_offset + 1024/*ds fmt 0, dmft1 */);
k_regs[0][1].i32x4 = *(vec4_int32*)(k_lds + k_lds_load_offset + 2048);
k_regs[1][1].i32x4 = *(vec4_int32*)(k_lds + k_lds_load_offset + 3072);
#else
int k_lds_load_offset = k_loop * WARP_N * kHeadDim;
k_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 0, 0, 3, 1, 0);
k_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 1024, 0, 3, 1, 0);
k_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 2048, 0, 3, 1, 0);
k_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(k_lds + k_lds_load_offset + 3072, 0, 3, 1, 0);
#endif
// ======================================================== QK mmac ======================================================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][0].i8x8[0], k_regs[n_idx][0].i8x8[0], s_reg[k_loop][m_idx][n_idx].f32);
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][0].i8x8[1], k_regs[n_idx][0].i8x8[1], s_reg[k_loop][m_idx][n_idx].f32);
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][1].i8x8[0], k_regs[n_idx][1].i8x8[0], s_reg[k_loop][m_idx][n_idx].f32);
s_reg[k_loop][m_idx][n_idx].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(q_regs[m_idx][1].i8x8[1], k_regs[n_idx][1].i8x8[1], s_reg[k_loop][m_idx][n_idx].f32);
}
}
}
}
}
#pragma once
#include "intrinsic_mls_ds.h"
#include "intrinsic_mls_ds_b8.h"
#define USE_MLS_128B_REQUEST
// Issues the matrix-load (MLS) builtins that move this warp's rows of Q from
// global memory (q_ptr) into LDS (q_lds).
//
// Template parameters:
//   Is_even_MN - not referenced anywhere in this body.
//   kHeadDim   - head dimension; only 128, 192 and 256 compile (static_assert).
//   WARP_M     - rows handled per warp; used for the LDS and global row offsets.
//   Element    - element type; its size scales the byte offsets.
// Parameters:
//   q_ptr            - base pointer handed to prepare_for_matrix_load.
//   q_lds            - LDS destination base.
//   warp_id          - selects which WARP_M-row slice this warp loads.
//   q_row_stride     - written verbatim into descriptor word 2.
//   max_seq_q_offset - drives the boundary filter placed in descriptor word 3.
//
// No wait is issued here; load_q_from_lds_to_vgpr executes `s_waitcnt vmcnt(0)`
// before reading the data back (presumably the loads are asynchronous - confirm).
//
// NOTE(review): prepare_for_matrix_load is instantiated with a literal 128 here,
// while fp8_prefetch_k_to_lds passes kHeadDim - confirm this is intentional.
// NOTE(review): the boundary filter uses a literal `32 * warp_id` whereas the
// address offsets use WARP_M; the two only agree when WARP_M == 32.
template<bool Is_even_MN, int kHeadDim, int WARP_M, typename Element>
__forceinline__ __device__ void fp8_prefetch_q_to_lds(
Element* q_ptr,
int8_t* q_lds,
int warp_id,
int q_row_stride,
int max_seq_q_offset
) {
static_assert(kHeadDim == 128 || kHeadDim == 192 || kHeadDim == 256);
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint q_root = prepare_for_matrix_load<128, Element>(q_ptr);
vec4_uint q_srsrc;
q_srsrc[0] = q_root[0];
q_srsrc[1] = q_root[1];
q_srsrc[2] = q_row_stride; // stride
q_srsrc[3] = 0x40000; // [17: 18], interleave 4
// Compute the LDS write address. A head dim of 192 uses a 256-wide LDS row stride.
constexpr int kLdsHeadDimStride = kHeadDim == 192 ? 256 : kHeadDim;
int q_lds_offset = warp_id * WARP_M * kLdsHeadDimStride * sizeof(Element);
// Only consumed by the #else (64x32) path below; unused when USE_MLS_128B_REQUEST is defined.
int q_lds_write_bytes = reinterpret_cast<size_t>(q_lds) + q_lds_offset;
// Compute the global-memory read address: advance word 0 to this warp's first row.
q_srsrc[0] = q_root[0] + (warp_id * WARP_M) * q_row_stride * sizeof(Element);
// Boundary check. inline_min_max<0, 16> presumably clamps its argument to [0, 16] - confirm.
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_q_offset);
// q_srsrc[3] = q_srsrc[3] + max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8; // set only once
// Word 3 = interleave flag plus the filter at bit 8 and up; the filter is left at 0
// whenever max_seq_q_offset is a multiple of 128. This overwrites the 0x40000 stored above.
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8)); // set only once
// printf("nm_filter is %d, max_seq_q_osffset is %d\n", max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8, max_seq_q_offset);
// Issue the MLS loads.
#ifdef USE_MLS_128B_REQUEST
// inline_matrix_load_128x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_offset/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
// First load: third argument (matrix offset) 0, LDS destination at the warp's base.
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset, 0, true, false, false, false, false);
// Second load: matrix offset 16, LDS destination +512 bytes; the filter is recomputed
// with +32 in place of +16.
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_q_offset);
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset+512, 16, true, false, false, false, false);
if constexpr (kHeadDim > 128) {
// Head dims above 128 need a second pair of loads. The global address is advanced by
// 64 elements for kHeadDim == 192 and by 128 elements for kHeadDim == 256, and the
// LDS destination moves 4096 bytes further on.
constexpr int kSecondMlsLoadHeadOffset = kHeadDim == 192 ? 64 : 128;
q_srsrc[0] = q_root[0] + (warp_id * WARP_M) * q_row_stride * sizeof(Element) + kSecondMlsLoadHeadOffset * sizeof(Element);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_q_offset);
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset+4096, 0, true, false, false, false, false);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_q_offset);
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset+4608, 16, true, false, false, false, false);
}
#else
// Fallback path: two 64x32 loads, 64 elements apart in global memory and 2048 bytes
// apart in LDS. NOTE(review): exactly two loads are issued whatever kHeadDim is, and the
// filter is not recomputed between them - confirm this path is only used for kHeadDim == 128.
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_write_bytes/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
q_srsrc[0] = q_srsrc[0] + 64 * sizeof(Element);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_write_bytes + 2048/*lds bytes*/, 0/*matrix_offset, 0 or 16*/); // For Q, a 128x16 or a non-4-interleave form could be considered
#endif
// Keep the compiler from scheduling instructions across the end of the load sequence.
__builtin_amdgcn_sched_barrier(0);
}
// #define USE_DS_READ_B128_FOR_INTERLEAVE4
// Reads this warp's Q tile from LDS (q_lds) into the q_regs register array.
//
// Template parameters:
//   kHeadDim - head dimension; only 128, 192 and 256 compile (static_assert).
//   WARP_M   - rows per warp; sizes q_regs and the per-warp LDS offset.
//   Element  - not referenced anywhere in this body.
// Parameters:
//   q_regs  - output, indexed [WARP_M / 16][kHeadDim / 64].
//   q_lds   - LDS base that the Q prefetch wrote to.
//   warp_id - selects this warp's slice of the LDS buffer.
//   lane_id - only used by the USE_DS_READ_B128_FOR_INTERLEAVE4 path.
//
// The function ends with __syncthreads(), so every thread of the block has to call it.
template<int kHeadDim, int WARP_M, typename Element>
__forceinline__ __device__ void load_q_from_lds_to_vgpr(
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64],
int8_t* q_lds,
int warp_id,
int lane_id
) {
// Wait until the vector-memory counter drops to zero (presumably the outstanding MLS
// loads issued by fp8_prefetch_q_to_lds - confirm); the scheduling barriers pin the
// wait in place.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
static_assert(kHeadDim == 128 || kHeadDim == 192 || kHeadDim == 256);
// A head dim of 192 uses a 256-wide LDS row stride, matching the prefetch functions.
constexpr int kLdsHeadDimStride = kHeadDim == 192 ? 256 : kHeadDim;
// The LDS data was written to two places; note that this is the rearrange layout, so the jump is 1K (transpose would jump 2K)
// MLS0: [0: 512) and [1024, 1536)
// MLS1: [512: 1024) and [1536, 2048)
// Registers are read in 4 passes; the first covers [0, 1024), i.e. a 16x64 tile, with each thread reading 16 fp8 values
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
// NOTE(review): this path uses kHeadDim (not kLdsHeadDimStride) for the per-warp offset
// and only fills q_regs[0..1][0..1]. The macro is commented out just above this
// function, so the path is currently not compiled - confirm before re-enabling it.
int row = (lane_id & 15) >> 1;
int col = lane_id >> 4;
int col_swizzle = (row + col) & 3;
int lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + warp_id * WARP_M * kHeadDim;
q_regs[0][0].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 0);
q_regs[1][0].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 1024/*ds fmt 0, dmft1 */);
q_regs[0][1].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 2048/*ds fmt 0, dmft1 */);
q_regs[1][1].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 3072/*ds fmt 0, dmft1 */);
#else
#pragma unroll
for (int h_idx = 0; h_idx < kHeadDim / 64; ++h_idx) {
// Each 64-wide head-dim chunk sits 2048 bytes after the previous one, except the third
// chunk of kHeadDim == 192, which is read from 3 * 2048 (the second half of the region
// that the prefetch's second load pair writes at +4096).
const int h_offset = (kHeadDim == 192 && h_idx == 2) ? 3 * 2048 : h_idx * 2048;
// Two reads per chunk, 1024 bytes apart, fill q_regs[0][h_idx] and q_regs[1][h_idx].
q_regs[0][h_idx].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + h_offset + 0 + warp_id * WARP_M * kLdsHeadDimStride, 0, 3, 1, 0);
q_regs[1][h_idx].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + h_offset + 1024 + warp_id * WARP_M * kLdsHeadDimStride, 0, 3, 1, 0);
}
#endif
// Block-wide barrier after the reads (presumably so the LDS region can be reused - confirm).
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
// Issues the matrix-load (MLS) builtins that move this warp's rows of K from
// global memory (k_ptr) into LDS (k_lds).
//
// Template parameters:
//   Is_even_MN - not referenced anywhere in this body.
//   kHeadDim   - head dimension; only 128, 192 and 256 compile (static_assert).
//   WARP_N     - rows per warp as far as the LDS offset is concerned.
//   Element    - element type; its size scales the byte offsets.
// Parameters:
//   k_ptr             - base pointer handed to prepare_for_matrix_load.
//   k_lds             - LDS destination base.
//   warp_id           - selects which row slice this warp loads.
//   k_row_stride      - written verbatim into descriptor word 2.
//   max_seq_kv_offset - drives the boundary filter placed in descriptor word 3.
//
// Calls flash::wait_all_warp_arrived() before issuing the loads, so presumably every
// warp of the block has to reach this function - confirm against that helper.
//
// NOTE(review): the global row offset and the boundary filter use a literal 32 while the
// LDS offset uses WARP_N; the three only agree when WARP_N == 32.
template<bool Is_even_MN, int kHeadDim, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_prefetch_k_to_lds(
Element* k_ptr,
int8_t* k_lds,
int warp_id,
int k_row_stride,
int max_seq_kv_offset
) {
static_assert(kHeadDim == 128 || kHeadDim == 192 || kHeadDim == 256);
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint k_root = prepare_for_matrix_load<kHeadDim, Element>(k_ptr);
vec4_uint k_srsrc;
k_srsrc[0] = k_root[0];
k_srsrc[1] = k_root[1];
k_srsrc[2] = k_row_stride; // stride
k_srsrc[3] = 0x40000; // [17: 18], interleave 4
// Compute the LDS write address. A head dim of 192 uses a 256-wide LDS row stride.
constexpr int kLdsHeadDimStride = kHeadDim == 192 ? 256 : kHeadDim;
int k_lds_offset = warp_id * WARP_N * kLdsHeadDimStride * sizeof(Element);
// Only consumed by the #else (64x32) path below; unused when USE_MLS_128B_REQUEST is defined.
int k_lds_write_bytes = reinterpret_cast<size_t>(k_lds) + k_lds_offset;
// Compute the global-memory read address: advance word 0 by 32 rows per warp.
k_srsrc[0] = k_root[0] + warp_id * 32 * k_row_stride * sizeof(Element);
// Boundary check. inline_min_max<0, 16> presumably clamps its argument to [0, 16] - confirm.
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_kv_offset);
// Add the filter at bit 8 and up on top of the interleave flag; the filter is left at 0
// whenever max_seq_kv_offset is a multiple of 128.
k_srsrc[3] = k_srsrc[3] + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
// Synchronize all warps so that the srsrc parameters are ready before the MLS loads are issued.
flash::wait_all_warp_arrived();
// Issue the MLS loads.
#ifdef USE_MLS_128B_REQUEST
// First load: third argument (matrix offset) 0, LDS destination at the warp's base.
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset, 0, true, false, false, false, false);
// Second load: matrix offset 16, LDS destination +512 bytes; the filter is recomputed
// with +32 in place of +16.
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_kv_offset);
k_srsrc[3] = 0x40000 + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset+512, 16, true, false, false, false, false);
if constexpr (kHeadDim > 128) {
// Head dims above 128 need a second pair of loads. The global address is advanced by
// 64 elements for kHeadDim == 192 and by 128 elements for kHeadDim == 256, and the
// LDS destination moves 4096 bytes further on.
constexpr int kSecondMlsLoadHeadOffset = kHeadDim == 192 ? 64 : 128;
k_srsrc[0] = k_root[0] + warp_id * 32 * k_row_stride * sizeof(Element) + kSecondMlsLoadHeadOffset * sizeof(Element);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_kv_offset);
k_srsrc[3] = 0x40000 + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset+4096, 0, true, false, false, false, false);
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_kv_offset);
k_srsrc[3] = 0x40000 + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset+4608, 16, true, false, false, false, false);
}
#else
// Fallback path: two 64x32 loads, 64 elements apart in global memory and 2048 bytes
// apart in LDS. NOTE(review): exactly two loads are issued whatever kHeadDim is - confirm
// this path is only used for kHeadDim == 128.
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, k_lds_write_bytes/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
k_srsrc[0] = k_srsrc[0] + 64 * sizeof(Element);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, k_lds_write_bytes + 2048/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
#endif
}
#pragma once
#include "philox.cuh"
#include "../utils.h"
using namespace flash;
// Sequence-length mask: every accumulator element whose column index is
// >= max_seq_kv_offset is overwritten with -INFINITY; all others are left alone.
//
// The column index of element `e` in s_reg[k][m][n] is
//   wave_col_offset + (lane_id >> 4) * 8 + n * 4 + k * WARP_N + e
// and does not depend on the row fragment `m`.
//
//   s_reg             - accumulators, indexed [kBlockN / WARP_N][WARP_M / 16][WARP_N / 16].
//   max_seq_kv_offset - first column index that must be masked.
//   wave_col_offset   - column offset added to every element of this wave.
//   lane_id           - lane index; only bits 4 and above enter the column index.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int max_seq_kv_offset,
    int wave_col_offset,
    int lane_id
) {
    constexpr int kLoopCount = kBlockN / WARP_N;
    constexpr int kRowFrags  = WARP_M / 16;
    constexpr int kColFrags  = WARP_N / 16;

    __builtin_amdgcn_sched_barrier(0);
    const int lane_first_col = wave_col_offset + (lane_id >> 4) * 8;
    #pragma unroll
    for (int kl = 0; kl < kLoopCount; ++kl) {
        #pragma unroll
        for (int nf = 0; nf < kColFrags; ++nf) {
            // Column of element 0 of this fragment; shared by every row fragment.
            const int frag_first_col = lane_first_col + nf * 4 + kl * WARP_N;
            #pragma unroll
            for (int mf = 0; mf < kRowFrags; ++mf) {
                auto& acc = s_reg[kl][mf][nf];
                #pragma unroll
                for (int e = 0; e < 4; ++e) {
                    if (frag_first_col + e >= max_seq_kv_offset) {
                        acc.f32[e] = -INFINITY;
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Causal mask: for each row, every accumulator element whose column index is greater
// than min(actual_seqlen_k, row + actual_seqlen_k - actual_seqlen_q) is overwritten
// with -INFINITY; all others are left alone.
//
// Index mapping used by this function:
//   row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3) + m * 4
//   col = wave_col_offset + (lane_id >> 4) * 8 + n * 4 + k * WARP_N + e
// for element `e` of s_reg[k][m][n].
//
// NOTE(review): the clamp is to actual_seqlen_k rather than actual_seqlen_k - 1, so
// column actual_seqlen_k stays unmasked, but only for rows >= actual_seqlen_q - confirm
// that those rows are discarded by the caller.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_causal_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int actual_seqlen_q,
    int actual_seqlen_k,
    int wave_row_offset,
    int wave_col_offset,
    int lane_id
) {
    constexpr int kLoopCount = kBlockN / WARP_N;
    constexpr int kRowFrags  = WARP_M / 16;
    constexpr int kColFrags  = WARP_N / 16;

    __builtin_amdgcn_sched_barrier(0);
    const int lane_first_row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int lane_first_col = wave_col_offset + (lane_id >> 4) * 8;
    // Offset between the key and query sequence lengths.
    const int seq_shift = actual_seqlen_k - actual_seqlen_q;
    #pragma unroll
    for (int mf = 0; mf < kRowFrags; ++mf) {
        const int row = lane_first_row + mf * 4;
        // Last column that stays unmasked for this row.
        const int last_kept_col = min(actual_seqlen_k, row + seq_shift);
        #pragma unroll
        for (int kl = 0; kl < kLoopCount; ++kl) {
            #pragma unroll
            for (int nf = 0; nf < kColFrags; ++nf) {
                const int frag_first_col = lane_first_col + nf * 4 + kl * WARP_N;
                auto& acc = s_reg[kl][mf][nf];
                #pragma unroll
                for (int e = 0; e < 4; ++e) {
                    if (frag_first_col + e > last_kept_col) {
                        acc.f32[e] = -INFINITY;
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
/**
 * Sliding-window (local attention) mask for the per-lane score fragment.
 *
 * For a row index `r`, with d = actual_seqlen_k - actual_seqlen_q, an element
 * at column `c` is kept only if
 *   c <= min(actual_seqlen_k - 1, r + d + window_size_right)
 * and, when window_size_left >= 0,
 *   c >= max(0, r + 1 + d - window_size_left) - 1.
 * Every other element is overwritten with -INFINITY.
 *
 * Index derivation used here:
 *   row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3) + m_idx * 4
 *   col = wave_col_offset + (lane_id >> 4) * 8 + n_idx * 4 + k_loop * WARP_N + vec_idx
 *
 * Fix over the previous version: the right limit was clamped to
 * `actual_seqlen_k` while being compared with a strict `>`, so column
 * `actual_seqlen_k` (one past the last valid key) was left unmasked whenever
 * r + d + window_size_right >= actual_seqlen_k. The inclusive limit is now
 * clamped to `actual_seqlen_k - 1`.
 *
 * @param s_reg              score fragments, modified in place
 * @param actual_seqlen_q    query sequence length
 * @param actual_seqlen_k    key sequence length
 * @param wave_row_offset    row offset added to every row index computed here
 * @param wave_col_offset    column offset added to every column index computed here
 * @param window_size_left   left window size; a negative value disables the left bound
 * @param window_size_right  right window size
 * @param lane_id            lane index used to derive per-lane row/column bases
 */
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_local_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int actual_seqlen_q,
    int actual_seqlen_k,
    int wave_row_offset,
    int wave_col_offset,
    int window_size_left,
    int window_size_right,
    int lane_id
) {
    __builtin_amdgcn_sched_barrier(0);
    const int row_base = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int col_base = wave_col_offset + (lane_id >> 4) * 8;
    const bool has_ws_left = window_size_left >= 0;
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        const int row_idx = row_base + m_idx * 4;
        const int col_limit_left = max(0, row_idx + 1 + actual_seqlen_k - actual_seqlen_q - window_size_left);
        // Inclusive right limit: never let it reach past the last valid key column.
        const int col_limit_right = min(actual_seqlen_k - 1, row_idx + actual_seqlen_k - actual_seqlen_q + window_size_right);
        #pragma unroll
        for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
            const int k_offset = k_loop * WARP_N;
            #pragma unroll
            for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
                const int n_base = col_base + n_idx * 4;
                #pragma unroll
                for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                    const int col_idx = n_base + k_offset + vec_idx;
                    // The left limit is stored as (first kept column + 1), hence the "- 1" in the compare.
                    s_reg[k_loop][m_idx][n_idx].f32[vec_idx] =
                        (col_idx > col_limit_right || (has_ws_left && col_idx < col_limit_left - 1))
                            ? -INFINITY
                            : s_reg[k_loop][m_idx][n_idx].f32[vec_idx];
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
/**
 * Multiplies every element of the score fragment by the QK de-scale factor.
 *
 * Each fragment is processed as two packed values (`u64[0]` and `u64[1]`),
 * each multiplied by `qk_descale` through `__builtin_hcu_pk_mul_f32`.
 *
 * @param s_reg       score fragments, scaled in place
 * @param qk_descale  packed multiplier applied to each `u64` half
 *                    (presumably the same scale in both components - confirm with the caller)
 */
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_qk_descale(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    __float2 qk_descale
) {
    constexpr int kTilesK = kBlockN / WARP_N;
    constexpr int kTilesM = WARP_M / 16;
    constexpr int kTilesN = WARP_N / 16;

    // Keep the scaling instructions from being interleaved with surrounding code.
    __builtin_amdgcn_sched_barrier(0);

    #pragma unroll
    for (int mt = 0; mt < kTilesM; ++mt) {
        #pragma unroll
        for (int kt = 0; kt < kTilesK; ++kt) {
            #pragma unroll
            for (int nt = 0; nt < kTilesN; ++nt) {
                auto& frag = s_reg[kt][mt][nt];
                // Two packed multiplies cover the whole fragment.
                frag.u64[0] = __builtin_hcu_pk_mul_f32(frag.u64[0], qk_descale);
                frag.u64[1] = __builtin_hcu_pk_mul_f32(frag.u64[1], qk_descale);
            }
        }
    }

    __builtin_amdgcn_sched_barrier(0);
}
// Online-softmax step for one block of scores, fused with the LDS -> register load of V.
//
// In order, the function:
//   1. computes the new per-row running maximum (seeded with scores_max) over s_reg and
//      combines it across lanes with xor-shuffles of 32 and 16;
//   2. replaces every score s with exp2(s * softmax_scale_log2 - max * softmax_scale_log2);
//   3. accumulates the per-row sum of those values (same xor-shuffle pattern);
//   4. waits for outstanding vector-memory ops (s_waitcnt vmcnt(0)), hits s_barrier, then
//      issues ds_read_matrix loads from v_lds into v_regs;
//   5. for rows whose maximum increased (and whose running sum is non-zero), rescales
//      scores_sum and acc_o by exp2((old_max - new_max) * softmax_scale_log2);
//   6. stores the new maximum into scores_max and adds the block sum to scores_sum.
//
// Template parameters:
//   AssumeValidRows - when false, rows whose maximum is -INFINITY are special-cased so that
//                     no (-INF) - (-INF) = NaN is produced.
// Parameters:
//   s_reg              - score fragments; overwritten with the exponentiated values
//   scores_max         - running per-row maximum (in/out)
//   scores_sum         - running per-row sum (in/out)
//   acc_o              - output accumulator, rescaled in place when the maximum grows
//   softmax_scale_log2 - multiplier applied to scores before exp2
//   v_regs             - destination registers for the V tile read from LDS
//   v_lds              - base pointer of the V tile in LDS
//
// NOTE(review): v_regs[k_loop][0..3] are written unconditionally, so this presumably expects
// kHeadDim >= 128 (kHeadDim / 32 >= 4) - confirm the set of instantiations.
template<bool AssumeValidRows, int kHeadDim, int kBlockN, int WARP_M, int WARP_N, int WARP_K, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_softmax_and_schedule_v(
/*softmax module related args*/
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
ElementAccum softmax_scale_log2,
/*scheduled modules related args*/
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32],
int8_t* v_lds
) {
// ======================================================== Max ======================================================================
ElementAccum scores_max_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// Start from the running maximum so the result is never smaller than the previous one.
ElementAccum max_value = scores_max[m_idx];
// The current thread walks the f32x4 outputs of the 4 32x32x32 mmac operations
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
max_value = max(max_value, s_reg[k_loop][m_idx][n_idx].f32[vec_idx]);
}
}
}
// These lines compare the data of lanes 0, 16, 32 and 48
max_value = max(max_value, __shfl_xor_tmp(max_value, 32));
max_value = max(max_value, __shfl_xor_tmp(max_value, 16));
// Store as the final maximum
scores_max_cur[m_idx] = max_value;
}
// ========================================== softmax ===============================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// Additive term of the fused multiply-add below: -max * scale, duplicated into both packed components.
__float2 max_scaled_pair;
if constexpr (AssumeValidRows) {
max_scaled_pair[0] = -scores_max_cur[m_idx] * softmax_scale_log2;
} else {
// A row maximum of -INFINITY uses 0 as the additive term instead of -(-INF) * scale.
max_scaled_pair[0] = scores_max_cur[m_idx] == -INFINITY ? 0.f: -scores_max_cur[m_idx] * softmax_scale_log2;
}
max_scaled_pair[1] = max_scaled_pair[0];
__float2 softmax_scale_log2_pair = {softmax_scale_log2, softmax_scale_log2};
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
// s = s * scale + (-max * scale), two packed values at a time
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
s_reg[k_loop][m_idx][n_idx].u64[vec_idx] = __builtin_hcu_pk_fma_f32(s_reg[k_loop][m_idx][n_idx].u64[vec_idx], softmax_scale_log2_pair, max_scaled_pair);
}
// s = exp2(s), element by element
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
s_reg[k_loop][m_idx][n_idx].f32[vec_idx] = __llvm_exp2_f32(s_reg[k_loop][m_idx][n_idx].f32[vec_idx]);
}
}
}
}
// ========================================== Sum ===============================================
ElementAccum scores_sum_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// Packed accumulator: both halves of every fragment are added into it, then its two components are summed.
vec2_Accum<ElementAccum> sum_pair;
sum_pair.data = 0.0;
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
sum_pair.u64 = __builtin_hcu_pk_add_f32(sum_pair.u64, s_reg[k_loop][m_idx][n_idx].u64[0]);
sum_pair.u64 = __builtin_hcu_pk_add_f32(sum_pair.u64, s_reg[k_loop][m_idx][n_idx].u64[1]);
}
}
scores_sum_cur[m_idx] = sum_pair.f32[0] + sum_pair.f32[1];
// Cross-lane reduction with the same xor pattern (32, then 16) as the max reduction above.
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor_tmp(scores_sum_cur[m_idx], 32);
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor_tmp(scores_sum_cur[m_idx], 16);
}
// Update scores_sum, scores_max
// This code is placed here because the large number of ds instructions about to be issued would compete with __shfl_xor for bandwidth, making the latency too high
// ElementAccum exp_rescale[WARP_M / 16];
// #pragma unroll
// for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// exp_rescale[m_idx] = __llvm_exp2_f32((scores_max[m_idx] - scores_max_cur[m_idx]) * softmax_scale_log2);
// scores_max[m_idx] = scores_max_cur[m_idx];
// scores_sum[m_idx] = __llvm_fma_f32(scores_sum[m_idx], exp_rescale[m_idx], scores_sum_cur[m_idx]);
// }
// ========================================== schedule V ===============================================
__builtin_amdgcn_sched_barrier(0);
// Wait for all outstanding vector-memory operations, then synchronize at s_barrier before reading V from LDS.
asm volatile("s_waitcnt vmcnt(0)\n\ts_barrier\n");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; k_loop += 1) {
// Use ds_read_matrix to read data from LDS into registers
int8_t* lds_load_ptr = v_lds + k_loop * WARP_M * kHeadDim * sizeof(Element);
v_regs[k_loop][0].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr, 0, 2, 2, 0);
v_regs[k_loop][1].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 32, 0, 2, 2, 0);
v_regs[k_loop][2].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 128 * 16, 0, 2, 2, 0);
v_regs[k_loop][3].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 128 * 16 + 32, 0, 2, 2, 0);
// kHeadDim == 256 reads four more tiles, starting 4096 bytes further into LDS.
if constexpr (kHeadDim == 256) {
v_regs[k_loop][4].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 4096, 0, 2, 2, 0);
v_regs[k_loop][5].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 4096 + 32, 0, 2, 2, 0);
v_regs[k_loop][6].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 4096 + 128 * 16, 0, 2, 2, 0);
v_regs[k_loop][7].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 4096 + 128 * 16 + 32, 0, 2, 2, 0);
}
}
__builtin_amdgcn_sched_barrier(0); // hint: consider issuing only part of the ds_read_matrix instructions here, to avoid stalling
// ========================================== rescale ===============================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; m_idx += 1) {
// Rescaling is skipped when nothing has been accumulated yet or when the maximum did not grow.
if (scores_sum[m_idx] != 0.f && scores_max[m_idx] < scores_max_cur[m_idx]) {
__float2 scores_scale_pair;
float max_diff;
if constexpr (AssumeValidRows) {
max_diff = scores_max[m_idx] - scores_max_cur[m_idx];
} else {
// Fix: when scores_max and scores_max_cur are both -INFINITY, (-INF) - (-INF) = NaN
// This happens when some query rows have no valid KV to attend to at all
max_diff = (scores_max[m_idx] == -INFINITY || scores_max_cur[m_idx] == -INFINITY)
? 0.f : (scores_max[m_idx] - scores_max_cur[m_idx]);
}
scores_scale_pair[0] = __llvm_exp2_f32(max_diff * softmax_scale_log2);
scores_scale_pair[1] = scores_scale_pair[0];
scores_sum[m_idx] *= scores_scale_pair[0];
// Rescale acc_o
// NOTE(review): the loop bounds use kHeadDim / WARP_N and WARP_K / 16 while acc_o is declared
// with kHeadDim / 32 and WARP_N / 16 - these only coincide when WARP_N == WARP_K == 32; confirm.
#pragma unroll
for (int pv_loop = 0; pv_loop < kHeadDim / WARP_N; ++pv_loop) {
#pragma unroll
for (int mmac_id = 0; mmac_id < WARP_K / 16; ++mmac_id) {
acc_o[pv_loop][m_idx][mmac_id].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[pv_loop][m_idx][mmac_id].u64[0], scores_scale_pair);
acc_o[pv_loop][m_idx][mmac_id].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[pv_loop][m_idx][mmac_id].u64[1], scores_scale_pair);
}
}
}
}
// Commit the new running maximum and fold this block's sum into the running sum.
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_max[m_idx] = scores_max_cur[m_idx];
scores_sum[m_idx] += scores_sum_cur[m_idx];
}
}
/**
 * Converts the f32 score fragments into fp8 values stored in `p_reg`.
 *
 * For every row tile, fragment [kt][row][nt] is passed to
 * `__builtin_hcu_cvt_pk4_fp8_f32<Element>` together with destination word
 * `p_reg[row].i32[kt * 2 + nt]`.
 *
 * NOTE(review): the destination stride of 2 per `kt` is hard-coded, which
 * presumably assumes WARP_N / 16 == 2 - confirm before instantiating with
 * another WARP_N.
 *
 * @param s_reg  source f32 fragments
 * @param p_reg  destination fp8 registers, one entry per row tile
 */
template<int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_cvt_f32_to_fp8(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    union_vec32_fp8 p_reg[WARP_M / 16]
) {
    constexpr int kTilesK = kBlockN / WARP_N;
    constexpr int kTilesM = WARP_M / 16;
    constexpr int kTilesN = WARP_N / 16;

    #pragma unroll
    for (int row_tile = 0; row_tile < kTilesM; ++row_tile) {
        #pragma unroll
        for (int kt = 0; kt < kTilesK; ++kt) {
            // First destination word used by this kt slice.
            const int dst_base = kt * 2;
            #pragma unroll
            for (int nt = 0; nt < kTilesN; ++nt) {
                __builtin_hcu_cvt_pk4_fp8_f32<Element>(
                    s_reg[kt][row_tile][nt].f32,
                    p_reg[row_tile].i32[dst_base + nt]);
            }
        }
    }
}
......@@ -13,11 +13,11 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
int seqlen_o_stride,
int seqlen_q_limit) {
static_assert (Is_Interleaved and "For fwd_epilogue_store_output_gfx938, mmac must be 4interleave");
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
static_assert (Is_Interleaved and "For fwd_epilogue_store_output_gfx938, mmac must be 4interleave");
if constexpr (TailTile16 == 2) {
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
......@@ -46,7 +46,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
}
}
}
}
} // brace, to control vgpr usage
} else {
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
......@@ -61,11 +61,13 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
const int pv_tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
const int mmac_id = min_tile_m + min_tile_n * 2;
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * 32 + min_tile_m * 16 + pv_lane_seq_idx;
// prepare for store
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
......@@ -79,7 +81,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
}
}
}
}
} // brace, to control vgpr usage
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
......@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
} else if constexpr (kHeadDimV == 192) {
int warp_id_m = warp_id % 2; // w0 w2
int warp_id_n = warp_id / 2; // w1 w3
......@@ -76,10 +73,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 16, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
// DS
lds_stage_id ^= 1;
......@@ -165,10 +159,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
}
}
......@@ -241,10 +232,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 16, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
lds_stage_id ^= 1;
......@@ -368,4 +356,4 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
prefetch_k_to_lds_mls_ds<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
}
}
\ No newline at end of file
......@@ -34,9 +34,7 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds(
int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 防止写 v lds 和读 k lds 冲突, qk 可能有的 warp 没结束
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
......@@ -79,10 +79,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
} else if constexpr (kHeadDim == 192) {
int warp_id_m = warp_id / 2;
int warp_id_n = warp_id % 2;
......@@ -95,10 +92,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// Wait MLS
......@@ -178,10 +172,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
}
}
......@@ -242,10 +233,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// Wait MLS
......@@ -361,7 +349,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
flash::wait_all_warp_arrived();
if constexpr (STAGES == 2) {
#if defined(__gfx938__) // 有的 prefetch v 写到了 mha 主 kernel 代码里
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE)) // 有的 prefetch v 写到了 mha 主 kernel 代码里
prefetch_v_to_lds_mls_ds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, TailTile16, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset);
#else
......@@ -369,3 +357,4 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
} // qk_gemm
......@@ -35,10 +35,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once
}
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
}
stage_id ^= 1;
......@@ -50,10 +47,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once
}
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
stage_id ^= 1;
// DS
......@@ -63,6 +57,8 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
int lds_load_offset = q_lds_base + (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
#ifdef __gfx938__
DS_READ_MATRIX_32X32_B16(lds_load_offset, q_reg[(k_loop - 1) * 2].f16, q_reg[(k_loop - 1) * 2 + 1].f16, true);
#elif defined(__gfx946__) || defined(__gfx92a__)
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, q_reg[(k_loop - 1) * 2].f16, q_reg[(k_loop - 1) * 2 + 1].f16, true);
#endif
// __syncthreads();
flash::wait_lds_data_arrived<true>(0);
......@@ -77,6 +73,8 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
int lds_load_offset = q_lds_base + (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
#ifdef __gfx938__
DS_READ_MATRIX_32X32_B16(lds_load_offset, q_reg[k_loop * 2].f16, q_reg[k_loop * 2 + 1].f16, true);
#elif defined(__gfx946__) || defined(__gfx92a__)
DS_READ_MATRIX_32X32_B16_GFX946(lds_load_offset, q_reg[k_loop * 2].f16, q_reg[k_loop * 2 + 1].f16, true);
#endif
}
__builtin_amdgcn_s_waitcnt(0);
......@@ -114,9 +112,7 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds(
}
int lds_offset = (stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
......@@ -171,10 +171,10 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + mi + ni * (WARP_M / 32);
int mmac_id = min_tile_n * 2 + min_tile_m;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair
);
......@@ -200,8 +200,8 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum_cur);
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......@@ -210,7 +210,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__))
inlineasm_fa_v_mov_b64(
scores_max[mi].u64,
scores_max_cur[mi].u64
......
......@@ -228,7 +228,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary[m_idx * 2].u64 = 0x0;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
......@@ -236,7 +236,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -262,7 +262,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
......@@ -270,7 +270,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -372,15 +372,14 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int qk_tile_id = mi + ni * (WARP_M / 32);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[qk_tile_id][mmac_id].u64[vec_idx] = hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
}
......@@ -418,6 +417,7 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
? scores_max_cur[mi * 2].f32[min_tile_m]
: (scores_max_cur[mi * 2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi * 2].f32[min_tile_m]);
// optimization from flash-attention-4
if (IsInference or scores_max[mi * 2].f32[min_tile_m] < scores_max_cur_reg) {
float scores_scale = __llvm_exp2_f32((scores_max[mi * 2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi * 2].f32[min_tile_m] *= scores_scale;
......@@ -428,13 +428,17 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
for(int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + mi + ni * (WARP_M / 32);
int mmac_id = min_tile_n * 2 + min_tile_m;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair
);
......@@ -460,25 +464,17 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum_cur);
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
#else // for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results
#else
scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
inlineasm_fa_v_mov_b64(
scores_max[mi].u64,
scores_max_cur[mi].u64
);
#else
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
#endif
}
}
};
......@@ -496,7 +492,12 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32x2[min_tile_k]);
#else
if constexpr (IsInference) {
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPairNoPack<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0],
......@@ -507,6 +508,7 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]
);
} else {
// For training, higher precision is needed
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
......@@ -516,15 +518,6 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
}
#else
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
#endif
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment