Commit a1eef562 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#pragma once
#include "numeric_types.h"
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_mla_epilugue_rescale_acco_gfx938(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
ElementAccum v_descale) {
#pragma unroll
for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? v_descale : v_descale / sum;
__float2 scale_pair = {inv_sum, inv_sum};
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
}
#else
for (int vec_id = 0; vec_id < 4; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].f32[vec_id] *= inv_sum;
}
#endif
}
}
}
}
}
}
\ No newline at end of file
#pragma once
#include "fwd/utils.h"
using namespace flash;
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_mask_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k,
const int col_idx_offset_ = 0) {
const int lane_id = threadIdx.x & 63; // lane id, 0-63
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 8;
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 4;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int mtp, const int layout) {
const int MTP_REGROUP_COUNT = max_seqlen_q / mtp;
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 8;
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
const int row_in_mtp = layout == 0 ? (row_idx % mtp): (row_idx / MTP_REGROUP_COUNT);
const int col_idx_limit_right = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 4;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_descale_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const __float2 qk_descale) {
#pragma unroll
for (int i = 0; i < M_WARP_COUNT * N_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
}
}
}
}
\ No newline at end of file
#pragma once
#include "fp8_mla_tp8_pv_gemm_utils_gfx938.h"
#include "fp8_mla_tp8_qk_gemm_utils_gfx938.h"
template<bool PrefetchK, int K_LOOP_COUNT, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_K_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT, typename V_Element, typename P_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint& k_addr,
V_Element* v_lds,
V_Element* k_lds,
union_vec2_f16x2<P_Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_v_offset,
int64_t k_addr_offset) {
static_assert (K_LOOP_COUNT % 2 == 0);
constexpr int K_LOOP_COUNT_ = K_LOOP_COUNT / (64 / kBlockN);
constexpr int PREFETCH = 2;
// 防止与多 wave reduce max 需要的 lds 冲突
flash::wait_lds_data_arrived<true/*sync*/>(0);
// 准备 MLS 的 resource 寄存器
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
// pingpong
int stage_id = 1;
#pragma unroll
for (int k_loop = K_LOOP_COUNT_ - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// lds 的写入地址
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
// global 随着 warp 的地址偏移
int warp_global_bytes; // = warp_id * 32 * v_row_stride * sizeof(V_Element);
// global 随着 k_loop 的地址偏移
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
// 计算边界
if constexpr (true) {
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
v_srsrc[3] = nm_filter << 8;
v_srsrc[3] += 0x20000;
}
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
// 等待 4 个 warp 数据写入 lds 完毕, 各 warp 之间数据不共享, 可以尝试不 sync
flash::wait_buffer_data_arrived<false/*sync*/>(PREFETCH);
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// 分配 v 计算 mmac 需要的寄存器资源
union_vec16_fp8 v_regs[2];
// 从 lds 读取数据到寄存器
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
// mmac
// P, fp16, 半精度
// V, fp8
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
// wait data written to registers
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// 16 fp8 for ds32x32_b8
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 = mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32
);
}
}
}
}
}
}
// 处理 K
*(int64_t*)&k_addr += k_addr_offset;
if constexpr (PrefetchK) {
fp8_mla_tp8_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockK);
flash::wait_buffer_data_arrived<false/*sync*/>(1);
} else {
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
{
constexpr int k_loop = 1 - PREFETCH;
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// 分配 v 计算 mmac 需要的寄存器资源
union_vec16_fp8 v_regs[2];
// 从 lds 读取数据到寄存器
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
// mmac
// P, fp16, 半精度
// V, fp8
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
// wait data written to registers
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// 16 fp8 for ds32x32_b8
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 = mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32
);
}
}
}
}
}
}
flash::wait_lds_data_arrived<true/*sync*/>(0); // here, K/V use more lds, and thus reuse togather, need sync
}
#pragma once
#include "fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h"
template<int K_LOOP_COUNT, int kBlockN, int WARP_NUM, typename V_Element>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_v_gfx938(
vec4_uint v_addr,
V_Element* v_lds,
int warp_id,
int v_row_stride,
int max_seq_v_offset=0) {
static_assert (K_LOOP_COUNT % 2 == 0);
constexpr int K_LOOP_COUNT_ = K_LOOP_COUNT / (64 / kBlockN);
constexpr int PREFETCH = 2;
// 防止与多 wave reduce max 需要的 lds 冲突
flash::wait_lds_data_arrived<true/*sync*/>(0);
// 准备 MLS 的 resource 寄存器
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
// pingpong
int stage_id = 0;
{
int k_loop = K_LOOP_COUNT_ - 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// 准备读取 V 32x64 个 fp8
// lds 的写入地址
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
// global 随着 warp 的地址偏移
int warp_global_bytes; // = warp_id * 32 * v_row_stride * sizeof(V_Element);
// global 随着 k_loop 的地址偏移
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
// 计算边界
if constexpr (true) {
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
v_srsrc[3] = nm_filter << 8;
v_srsrc[3] += 0x20000;
}
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
}
}
#pragma once
#include "fp8_mla_tp8_qk_gemm_utils_gfx938.h"
template<int kHeadDim, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_mla_tp8_qk_gemm_gfx938(
vec4_uint k_addr,
Element* k_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int k_row_stride,
int max_seq_k_offset=0) {
int stage_id = 0;
// 准备 MLS resource 寄存器
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = k_row_stride;
// 初始化 s
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
asm volatile(
"v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
: "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[0]), "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[1])
:);
}
}
}
// round
stage_id ^= 1;
#pragma unroll
for (int k_loop = 1; k_loop < kHeadDim / 64; ++k_loop) {
// lds 的写入地址
int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
// global 随着 warp 的地址偏移
int warp_global_bytes; // = warp_id * 32 * k_row_stride * sizeof(Element);
// global 随着 k_loop 的地址偏移
int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
// 计算边界
if constexpr (true) {
int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
k_srsrc[3] = nm_filter << 8;
k_srsrc[3] += 0x40000;
}
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + warp_global_bytes + k_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
// 等待 4 个 warp 数据写入 lds 完毕, 各 warp 之间数据不共享, 可以尝试不 sync
flash::wait_buffer_data_arrived<false/*sync*/>(1);
// round
stage_id ^= 1;
// 分配 k 计算 mmac 需要的寄存器资源
union_vec16_fp8 k_regs[WARP_N / 16];
// 从 lds 读取数据到寄存器
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
// mmac
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 等待数据写到寄存器
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32
);
}
}
}
}
{
constexpr int k_loop = kHeadDim / 64;
// 等待 4 个 warp 数据写入 lds 完毕, 各 warp 之间数据不共享, 可以尝试不 sync
flash::wait_buffer_data_arrived<false/*sync*/>(0);
stage_id ^= 1;
// 分配 k 计算 mmac 需要的寄存器资源
union_vec16_fp8 k_regs[WARP_N / 16];
// 从 lds 读取数据到寄存器
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
// mmac
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 等待数据写到寄存器
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32
);
}
}
}
}
// need to reduce results on scores_max and prefetch V, and thus sync
flash::wait_lds_data_arrived<true/*sync*/>(0);
} // qk_gemm
#pragma once
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds_b8.h"
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, typename ElementAccum, int STAGES, int M_MMAC_COUNT>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_q_to_vgpr_gfx938_with_initialization(
vec4_uint q_addr,
Element* q_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
int warp_id,
int q_row_stride,
int max_seq_q_offset,
vec2_Accum<ElementAccum> scores_max[WARP_M / 32],
vec2_Accum<ElementAccum> scores_sum[WARP_M / 32],
vec4_Accum<ElementAccum> acc_o[kHeadDimV / kBlockK][4]) {
// 准备 MLS 寄存器
vec4_uint q_srsrc;
q_srsrc[0] = q_addr[0];
q_srsrc[1] = q_addr[1];
q_srsrc[2] = q_row_stride;
q_srsrc[3] = 0;
// 计算 lds 写入地址
int q_lds_write_bytes = warp_id * 16 * 128 * sizeof(Element);
// 计算 global 读取地址
int q_mls_warp_global_offset = warp_id * 128 * sizeof(Element);
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_mls_warp_global_offset);
// mls 读取 16x128 bytes
if constexpr (true) {
int nm_filter = inline_min_max<0, 16>(16 - max_seq_q_offset);
q_srsrc[3] = nm_filter << 8;
}
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, q_lds_write_bytes, 0);
// add alu between def-use
attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, 1, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
// 等待 4 个 warp 数据写入 lds 完毕
flash::wait_buffer_data_arrived<true/*sync*/>(0);
// 从 lds 读取数据
#pragma unroll
for (int i = 0; i < WARP_NUM; ++i) {
int q_lds_load_offset = reinterpret_cast<size_t>(q_lds) + (i * 16 * 128) * sizeof(Element);
DS_READ_MATRIX_64x16_B8(q_lds_load_offset, q_reg[0][i * 2].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[0][i * 2 + 1].i32x4, true/*transpose*/)
}
__builtin_amdgcn_sched_barrier(0);
// 接着读取剩下的 16x64
// =====================================================================================================
if (warp_id == 0) {
// [RTL bug] MLS 128B 请求指令使用 m_filter 需要限制起始地址和 stride 都是 128B 对齐, 否则在访问矩阵最后一行末尾时, 若地址跨越 64B, 一定概率跨越了页表, 导致 invalid address
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + ((WARP_NUM - 1) * 128 + 64) * sizeof(Element));
inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds + 16384, q_srsrc, q_lds_write_bytes, 0);
// 等待数据写到 lds
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
// sync
flash::wait_all_warp_arrived();
// 每个 warp 读取 16x64 的内容
int q_lds_load_offset = reinterpret_cast<size_t>(q_lds + 16384) * sizeof(Element);
DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[0][8].i32x4, true/*transpose*/)
// 同步, 等待数据写到寄存器, 同时防止 lds 被新的 mls 指令写入
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
template<int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_k_gfx938(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int k_row_stride,
int max_seq_k_offset=0) {
int stage_id = 0;
// 准备 MLS resource 寄存器
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = k_row_stride;
{
constexpr int k_loop = 0;
// lds 的写入地址
int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
// global 随着 warp 的地址偏移
int warp_global_bytes; // = warp_id * 32 * k_row_stride * sizeof(Element);
// global 随着 k_loop 的地址偏移
int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
// 计算边界
if constexpr (true) {
int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
k_srsrc[3] = nm_filter << 8;
k_srsrc[3] += 0x40000;
}
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + warp_global_bytes + k_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
}
\ No newline at end of file
#include "numeric_types.h"
#include "intrinsic.h"
#define CUDART_L2E_F 1.442695041F
// DataType: {vec2_Accum<ElementAccum>, vec_Accum<ElementAccum>}
template<int WARP_M, int kBlockK, int kHeadDimV, bool Is_dropout, typename ElementAccum, typename DataType=union_vec2_fp32/* vec2_Accum<ElementAccum> */, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_mla_epilugue_rescale_acco(
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
const ElementAccum scale_softmax,
const ElementAccum rp_dropout) {
// Epilogue
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse[mi].f32[min_tile_m] = (sum == 0.f || sum != sum) ? -INFINITY : scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(sum);
ElementAccum scale = Is_dropout ? inv_sum * rp_dropout: inv_sum;
__float2 scale_pair = {scale, scale};
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_n * 2 + min_tile_m;
} else {
mmac_id = min_tile_n;
}
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
}
#else
for(int vec_id = 0; vec_id < 4; ++vec_id) {
acc_o[pv_tile_id][mmac_id].f32[vec_id] *= scale;
}
#endif
}
}
}
}
}
}
template<int WARP_M, int kBlockK, int kHeadDimV, bool Is_dropout, typename ElementAccum, typename DataType=union_vec2_fp32/* vec2_Accum<ElementAccum> */, int M_MMAC_COUNT=2>
__forceinline__ __device__ void decode_dsa_epilugue_rescale_acco(
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
const ElementAccum scale_softmax,
const ElementAccum rp_dropout,
float* attn_sink) {
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
// Epilogue
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse[mi].f32[min_tile_m] = (sum == 0.f || sum != sum) ? -INFINITY : scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(sum);
float attn_sink_o_scale = 1.0f;
if (attn_sink != nullptr) {
float rAttn_sink = attn_sink[warp_id * 16 + tid % 16];
if (rAttn_sink == INFINITY) {
attn_sink_o_scale = 0.0f;
} else if ((lse[mi].f32[min_tile_m] != -INFINITY) && (lse[mi].f32[min_tile_m] != INFINITY)) {
float lse_exp2 = __builtin_amdgcn_exp2f(lse[mi].f32[min_tile_m] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
}
}
ElementAccum scale = inv_sum * attn_sink_o_scale;
__float2 scale_pair = {scale, scale};
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_n * 2 + min_tile_m;
} else {
mmac_id = min_tile_n;
}
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
}
#else
for(int vec_id = 0; vec_id < 4; ++vec_id) {
acc_o[pv_tile_id][mmac_id].f32[vec_id] *= scale;
}
#endif
}
}
}
}
}
}
template<int WARP_M, int kBlockK, int kHeadDimV, bool Is_dropout, typename ElementAccum, typename DataType=union_vec2_fp32/* vec2_Accum<ElementAccum> */, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_dsa_epilugue_rescale_acco(
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
const ElementAccum scale_softmax,
const ElementAccum rp_dropout,
float* attn_sink) {
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
// Epilogue
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse[mi].f32[min_tile_m] = (sum == 0.f || sum != sum) ? -INFINITY : scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(sum);
float attn_sink_o_scale = 1.0f;
if (attn_sink != nullptr) {
float rAttn_sink = attn_sink[warp_id * 16 + tid % 16];
if (rAttn_sink == INFINITY) {
attn_sink_o_scale = 0.0f;
} else if ((lse[mi].f32[min_tile_m] != -INFINITY) && (lse[mi].f32[min_tile_m] != INFINITY)) {
float lse_exp2 = __builtin_amdgcn_exp2f(lse[mi].f32[min_tile_m] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
}
}
ElementAccum scale = inv_sum * attn_sink_o_scale;
__float2 scale_pair = {scale, scale};
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_n * 2 + min_tile_m;
} else {
mmac_id = min_tile_n;
}
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
}
#else
for(int vec_id = 0; vec_id < 4; ++vec_id) {
acc_o[pv_tile_id][mmac_id].f32[vec_id] *= scale;
}
#endif
}
}
}
}
}
}
template<int WARP_M, bool Is_even_MN, bool SplitD, bool Is_Interleaved, typename ElementAccum, typename DataType=union_vec2_fp32/* vec2_Accum<ElementAccum> */, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_mla_epilogue_store_lse(
DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
void *softmax_lse_ptr,
int row_offset_lse,
int warp_id,
int lane_id,
int headdim_split_id,
int seqlen_q_limit) {
ElementAccum * gLSE = reinterpret_cast<ElementAccum*>(softmax_lse_ptr) + row_offset_lse;
#if (DEBUG_LEVEL >= 1)
ElementAccum * scores_sum_ptr = reinterpret_cast<ElementAccum*>(scores_sum_ptr) + row_offset_lse;
ElementAccum * scores_max_ptr = reinterpret_cast<ElementAccum*>(scores_max_ptr) + row_offset_lse;
#endif
const bool write_lse = SplitD > 1 ? (lane_id >> 4) == 0 and headdim_split_id == 0: (lane_id >> 4) == 0;
if (write_lse) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row = Is_Interleaved
? warp_id * WARP_M + mi * (16 * M_MMAC_COUNT) + (lane_id & 15) + min_tile_m * 16
: warp_id * WARP_M + mi * (16 * M_MMAC_COUNT) + (lane_id & 15) * 2 + min_tile_m;
if constexpr (Is_even_MN) {
gLSE[row] = lse[mi].f32[min_tile_m];
#if (DEBUG_LEVEL >= 1)
scores_sum_ptr[row] = scores_sum[mi].f32[min_tile_m];
scores_max_ptr[row] = scores_max[mi].f32[min_tile_m];
#endif
} else {
if (row < seqlen_q_limit) {
gLSE[row] = lse[mi].f32[min_tile_m];
#if (DEBUG_LEVEL >= 1)
scores_sum_ptr[row] = scores_sum[mi].f32[min_tile_m];
scores_max_ptr[row] = scores_max[mi].f32[min_tile_m];
#endif
}
}
}
}
}
}
template<int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, bool Is_even_MN, bool Is_Interleaved, bool TcpSwizzle, typename Element, typename ElementAccum, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_mla_epilogue_store_output(
Element *o_ptr,
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
int m_block,
int warp_id,
int lane_id,
int seqlen_o_stride,
int seqlen_q_limit) {
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
if constexpr (Is_Interleaved) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
union_vec2_f16x2<Element> acc_o_fp16[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT];
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2)
mmac_id = min_tile_m + min_tile_n * 2;
else
mmac_id = min_tile_n;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
acc_o_fp16[k_loop][mmac_id].f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[k_loop][mmac_id].f32x2[vec_index]);
}
ds_mpermute_kdim_for_mmac(acc_o_fp16[k_loop][mmac_id].f32);
}
}
}
#endif
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
flash::wait_lds_data_arrived<false>((kHeadDimV / kBlockK - k_loop - 1) * 2 * 2);
#endif
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int pv_tile_id = k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_m + min_tile_n * 2;
} else {
mmac_id = min_tile_n;
}
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT) + min_tile_m * 16 + pv_lane_seq_idx;
// prepare for store
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = acc_o_fp16[k_loop][mmac_id];
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = acc_o_fp16[k_loop][mmac_id];
}
#else
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
#endif
}
}
}
}
} // brace, to control vgpr usage
} else { // 仅支持LIT的部分
auto gO = prepare_for_buffer_load<kHeadDimV, Element>(o_ptr);
#pragma unroll
for(int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
for(int warp_m_idx = 0; warp_m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++warp_m_idx) {
#pragma unroll
for(int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int vec_index = 0; vec_index < 4; ++vec_index) {
if constexpr (not Is_even_MN) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT) + pv_lane_seq_idx + min_tile_m * 16; /*算的是 1 个 kBlockM 内在 seqlen_q 方向上的位置*/
int pv_global_addr = seqlen_q_offset * seqlen_o_stride + /*headdim 方向上的偏移*/k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
if(m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
if constexpr (M_MMAC_COUNT == 2)
o_ptr[pv_global_addr] = DownCast<ElementAccum, Element>(acc_o[k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx][min_tile_m + min_tile_n * 2].f32[vec_index]);
else
o_ptr[pv_global_addr] = DownCast<ElementAccum, Element>(acc_o[k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx][min_tile_n].f32[vec_index]);
}
}
}
else {
int tile32x32_id = k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
int s_offset = k_loop * kBlockK;
int s_offset_constexpr = k_tile_idx * 32 + vec_index * 8; /*overflow for s_offset_constexpr*/
int v_offset = (warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT) + pv_lane_seq_idx + min_tile_m * 16) * seqlen_o_stride + pv_lane_head_dim_idx * 2;
vec2_Element<Element> v_data;
// convert float -> bf16/fp16
if constexpr (std::is_same<Element, bhalf_t>::value) {
#if 1
v_data[0] = DownCast<ElementAccum, Element, true>(acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index]);
v_data[1] = DownCast<ElementAccum, Element, true>(acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
#else
v_data[0] = inlineasm_float2bfloat16_ushort_nonan(acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index]);
v_data[1] = inlineasm_float2bfloat16_ushort_nonan(acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
#endif
}
else if constexpr (std::is_same<Element, half_t>::value) {
#ifdef USE_CVT_PKRTZ_FP16_FP32
*(vec2_Element<Element>*)&v_data = DownCastPair<ElementAccum, Element>(
acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index],
acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]
);
#else
v_data[0] = DownCast<ElementAccum, Element>(acc_o[tile32x32_id][min_tile_m + 0 * 2].f32[vec_index]);
v_data[1] = DownCast<ElementAccum, Element>(acc_o[tile32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
#endif
}
// write to global memory
inline_buffer_store_dword<vec2_Element<Element>, 1>(v_data, v_offset, gO, s_offset, /* immediate integer */s_offset_constexpr);
}
}
}
}
}
} // brace, to control vgpr usage
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
}
}
\ No newline at end of file
#include "mla_qk_gemm_utils_mls_ds.h"
#include "static_switch.h"
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int* block_table,
int batch_stride,
int page_block_size,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int READ_ONCE_COUNT = 32 * 32;
constexpr int kHeadDimV_OPT = 256; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + block_table[index_topk_1/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + block_table[index_topk_2/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if ((k_loop & 7) == 0x0) {
int n_loop_ = n_loop;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + block_table[index_topk_1/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + block_table[index_topk_2/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + block_table[index_topk_1/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + block_table[index_topk_2/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
// DS
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
if constexpr (PREFETCH_K) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage(
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int page_block_size,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int READ_ONCE_COUNT = 32 * 32;
constexpr int kHeadDimV_OPT = 256; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if ((k_loop & 7) == 0x0) {
int n_loop_ = n_loop;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
// DS
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
if constexpr (PREFETCH_K) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64(
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int page_block_size,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int READ_ONCE_COUNT = 32 * 32;
constexpr int kHeadDimV_OPT = 256; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if ((k_loop & 7) == 0x0) {
int n_loop_ = n_loop;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 7) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
}
// DS
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
if constexpr (PREFETCH_K) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new(
Element* k_faker,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int page_block_size,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int READ_ONCE_COUNT = 32 * 32;
constexpr int kHeadDimV_OPT = 128; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int index_topk_1[2];
index_topk_1[0] = index_ptr[((n_loop_real * 64)) + (tid / 4)];
index_topk_1[1] = index_ptr[((n_loop_real * 64)) + (tid / 4) + 32];
int index_topk_2[2];
index_topk_2[0] = index_ptr[((n_loop_real * 64)) + (tid / 4) + 16];
index_topk_2[1] = index_ptr[((n_loop_real * 64)) + (tid / 4) + 48];
// int index_topk_2 = index_ptr[((n_loop_real * 64)) + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = tid % 4 * 4 + index_topk_1[0] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s = warp_id * 16;
int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = tid % 4 * 4 + index_topk_2[0] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s_add = warp_id * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, g_offset_s_add, g_offset_v_add);
// __builtin_amdgcn_sched_barrier(0);
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v = tid % 4 * 4 + index_topk_1[0] * seqlen_v_stride * ELEMENT_BYTES / 4;
g_offset_s = n_load * WARP_NUM * 16 + warp_id * 16;
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v_add = tid % 4 * 4 + index_topk_2[0] * seqlen_v_stride * ELEMENT_BYTES / 4;
g_offset_s_add = n_load * WARP_NUM * 16 + warp_id * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, g_offset_s_add, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 3) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
}
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if ((k_loop & 3) == 0x0) {
int n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// index_topk_1 = index_ptr[k_loop / 12 * 32 + (n_loop_real * 64) + (tid / 4)];
// index_topk_2 = index_ptr[k_loop / 12 * 32 + (n_loop_real * 64) + (tid / 4) + 16];
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v = tid % 4 * 4 + index_topk_1[k_loop / 12] * seqlen_v_stride * ELEMENT_BYTES / 4;
g_offset_s = ((k_loop + 4) & 15) * 16 + warp_id * 16;
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v_add = tid % 4 * 4 + index_topk_2[k_loop / 12] * seqlen_v_stride * ELEMENT_BYTES / 4;
g_offset_s_add = ((k_loop + 4) & 15) * 16 + warp_id * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, g_offset_s_add, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
}
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
// lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// int index_topk_1 = index_ptr[n_loop_ * 32 + (n_loop_real * 64) + (tid / 4)];
// int index_topk_2 = index_ptr[n_loop_ * 32 + (n_loop_real * 64) + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = tid % 4 * 4 + index_topk_1[1] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s = n_load * WARP_NUM * 16 + warp_id * 16;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = tid % 4 * 4 + index_topk_2[1] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s_add = n_load * WARP_NUM * 16 + warp_id * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, g_offset_s_add, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 3) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
}
// DS
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if (k_loop == 4 || k_loop == 8) {
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// int index_topk_1 = index_ptr[32 + (n_loop_real * 64) + (tid / 4)];
// int index_topk_2 = index_ptr[32 + (n_loop_real * 64) + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = tid % 4 * 4 + index_topk_1[1] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s = ((k_loop + 4) & 15) * 16 + warp_id * 16;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = tid % 4 * 4 + index_topk_2[1] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s_add = ((k_loop + 4) & 15) * 16 + warp_id * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, g_offset_s_add, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
}
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
int abc[1];
int index_topk = index_ptr[(((n_loop_real+1) % 16) * 64) + warp_id * 16];
int offset_m = index_topk * seqlen_k_stride;
auto g_abc = (reinterpret_cast<uint64_t>(k_faker + offset_m));
inline_s_load_dword(abc[0], g_abc, 0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
// if constexpr (PREFETCH_K) {
// prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
// prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
// }
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_666(
Element* k_faker,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 64;
constexpr int READ_ONCE_COUNT = 16 * 32;
constexpr int kHeadDimV_OPT = 64; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 64 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int index_topk = index_ptr[(n_loop_real * 64) + warp_id * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
int g_offset_v = tid % 4 * 4 + index_topk * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s = 0 * WARP_N * ELEMENT_BYTES / 4 + 0;
int g_offset_s_2 = 0 * WARP_N * ELEMENT_BYTES / 4 + 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
int lds_stage_id = 1;
for(int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
{
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
g_offset_s = k_loop * WARP_N * ELEMENT_BYTES / 4;
g_offset_s_2 = k_loop * WARP_N * ELEMENT_BYTES / 4 + 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
int k_loop = kHeadDimV / kBlockN;
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 0;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = (k_loop - 1) * 2 + 1;
int v_tile_id = 4 + stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999(
Element* k_faker,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int page_block_size,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 16;
constexpr int READ_ONCE_COUNT = 16 * 32;
constexpr int kHeadDimV_OPT = 256; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 16 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 256 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[WARP_N / 32];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop_real * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop_real * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop_real * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop_real * 64) + 3 * 16 + (tid / 4)];
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
int index_block = index_topk[0] / page_block_size;
int index_offset = index_topk[0] - index_block * page_block_size;
int g_offset_v = tid % 4 * 4
+ (index_block * batch_stride + index_offset * seqlen_v_stride) * ELEMENT_BYTES / 4;
int g_offset_s = warp_id * 32 * ELEMENT_BYTES / 4;
int g_offset_s_2 = warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
int lds_stage_id = 1;
for(int total_loop = 1; total_loop < (kHeadDimV / kBlockN) * 4; ++total_loop) {
{
index_block = index_topk[total_loop / 2] / page_block_size;
index_offset = index_topk[total_loop / 2] - index_block * page_block_size;
g_offset_v = tid % 4 * 4
+ (index_block * batch_stride + index_offset * seqlen_v_stride) * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4;
g_offset_s_2 = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 0;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 1;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 2;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 3;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 4;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 5;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 6;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 7;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 0;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 1;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 2;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 3;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 4;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 5;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// int abc[1];
// int index_topk_qk = index_ptr[(((n_loop_real+1) % 16) * 64) + warp_id * 16];
// int offset_m = index_topk_qk * seqlen_k_stride;
// auto g_abc = (reinterpret_cast<uint64_t>(k_faker + offset_m));
// inline_s_load_dword(abc[0], g_abc, 0);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 6;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 7;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_2(
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 16;
constexpr int READ_ONCE_COUNT = 32 * 16;
constexpr int kHeadDimV_OPT = 128; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
// static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (16 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 16 + n_loop_real * 64 + (tid / 4)];
// int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
// int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
// int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 3) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
}
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(1);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(n_loop - 1)/2][(n_loop - 1)%2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if ((k_loop & 3) == 0x0) {
int n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[(n_loop - 1) * 16 + k_loop / 12 * 16 + n_loop_real * 64 + (tid / 4)];
// int index_topk_2 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
// int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
// int g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
}
}
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(0);
// last mmac
flash::raise_priority();
{
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(n_loop - 1)/2][(n_loop - 1)%2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
// lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 16 + n_loop_real * 64 + (tid / 4)];
// int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
// int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
// int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 3) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
}
// DS
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(1);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(n_loop - 1)/2][(n_loop - 1)%2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if (k_loop == 4 || k_loop == 8) {
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[48 + n_loop_real * 64 + (tid / 4)];
// int index_topk_2 = index_ptr[32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
// int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
// int g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
}
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(0);
// last mmac
flash::raise_priority();
{
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(n_loop - 1)/2][(n_loop - 1)%2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
if constexpr (PREFETCH_K) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_1(
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int READ_ONCE_COUNT = 32 * 32;
constexpr int kHeadDimV_OPT = 128; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 32 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 32 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int lds_stage_id = 1;
for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
// prefetch same warpk, next 32x256 G2S
{
int n_load = 1;
int n_loop_ = n_loop - 1;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
int k_loop = 1;
stage_id ^= 1;
for (; k_loop < 4; ++k_loop) {
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// Wait for special headdim
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
int n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
k_loop++;
for (; k_loop < 8; ++k_loop) {
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// Wait for special headdim
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4)];
index_topk_2 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4) + 16];
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
k_loop++;
for (; k_loop < 12; ++k_loop) {
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// Wait for special headdim
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4)];
index_topk_2 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4) + 16];
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
k_loop++;
for (; k_loop < 16; ++k_loop) {
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// Wait for special headdim
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
n_loop_ = n_loop;
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4)];
index_topk_2 = index_ptr[k_loop / 12 * 32 + n_loop_real * 64 + (tid / 4) + 16];
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
k_loop++;
stage_id ^= 1;
// Wait DS
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
{
constexpr int n_loop = kBlockK / WARP_K;
// MLS for special headdimV
{
constexpr int n_loop_ = n_loop - 1;
int n_load = 1;
// lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id ^= 1;
int stage_id = 0;
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS*2); // [TODO]更早的预取
// DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
stage_id ^= 1;
for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
// Wait for special headdim
if ((k_loop & 3) == 0x0) {
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
}
// DS
int lds_load_offset = v_lds_base + ((k_loop & 3) * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
flash::wait_lds_data_arrived<false>(3);
// MMAC
flash::raise_priority();
stage_id ^= 1;
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(2);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (STAGES == 2) ? k_loop - 1: k_loop;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
// MLS for special headdimV
if (k_loop == 4 || k_loop == 8) {
lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int index_topk_1 = index_ptr[32 + n_loop_real * 64 + (tid / 4)];
int index_topk_2 = index_ptr[32 + n_loop_real * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = ((k_loop + 4) & 15) * 16 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id ^= 1;
}
}
stage_id ^= 1;
flash::wait_lds_data_arrived<false>(1);
// last mmac
flash::raise_priority();
{
constexpr int min_tile_k = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::wait_lds_data_arrived<false>(0);
{
constexpr int min_tile_k = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
int pv_tile_id = (kHeadDimV / kBlockN) - 1;
int v_tile_id = stage_id * 2 + min_tile_k;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[n_loop - 1][min_tile_k].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
}
flash::lower_priority();
}
// 预取Q K
if constexpr (PREFETCH_K) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic_mls_ds.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512(
vec4_uint v_ptr,
Element* v_lds,
int warp_id,
int seqlen_v_stride,
int max_seq_kv_offset=0) {
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int kHeadDim_OPT = 256; // 32x32 x WARP_NUM x 2B x stage == 32K
// MLS
int n_load = 0;
vec4_uint v_srsrc;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_load * WARP_NUM * 32 + warp_id * 32) * ELEMENT_BYTES);
v_srsrc[2] = seqlen_v_stride;
if constexpr (true) {
int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset);
v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
}
int lds_stage_id = 0;
int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load(
vec4_uint v_ptr,
Element* v_lds,
int warp_id,
int seqlen_v_stride,
int* index_ptr,
int* block_table,
int batch_stride,
int n_loop,
int max_seq_kv_offset=0) {
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int kHeadDim_OPT = 256; // 32x32 x WARP_NUM x 2B x stage == 32K
// MLS
// int n_load = 0;
// vec4_uint v_srsrc;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_load * WARP_NUM * 32 + warp_id * 32) * ELEMENT_BYTES);
// v_srsrc[2] = seqlen_v_stride;
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_stage_id = 0;
// int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int tid = threadIdx.x % 64;
int index_topk_1 = index_ptr[n_loop * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + block_table[index_topk_1/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + block_table[index_topk_2/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
__builtin_amdgcn_sched_barrier(0);
}
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage(
vec4_uint v_ptr,
Element* v_lds,
int warp_id,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop,
int max_seq_kv_offset=0) {
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int kHeadDim_OPT = 256; // 32x32 x WARP_NUM x 2B x stage == 32K
// MLS
// int n_load = 0;
// vec4_uint v_srsrc;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_load * WARP_NUM * 32 + warp_id * 32) * ELEMENT_BYTES);
// v_srsrc[2] = seqlen_v_stride;
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_stage_id = 0;
// int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int tid = threadIdx.x % 64;
int index_topk_1 = index_ptr[n_loop * 64 + (tid / 4)];
int index_topk_2 = index_ptr[n_loop * 64 + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
__builtin_amdgcn_sched_barrier(0);
}
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64(
vec4_uint v_ptr,
Element* v_lds,
int warp_id,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop,
int max_seq_kv_offset=0) {
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 32;
constexpr int kHeadDim_OPT = 256; // 32x32 x WARP_NUM x 2B x stage == 32K
// MLS
// int n_load = 0;
// vec4_uint v_srsrc;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_load * WARP_NUM * 32 + warp_id * 32) * ELEMENT_BYTES);
// v_srsrc[2] = seqlen_v_stride;
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_stage_id = 0;
// int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int tid = threadIdx.x % 64;
int index_topk_1 = index_ptr[((n_loop * 64)) + (tid / 4)];
int index_topk_2 = index_ptr[((n_loop * 64)) + (tid / 4) + 16];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane(((warp_id + WARP_NUM) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + (warp_id + WARP_NUM) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "mla_pv_gemm_utils_mls_ds.h"
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int* block_table,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 2;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 准备 q,k 寄存器
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32)];
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int q_lds_base = reinterpret_cast<size_t>(q_lds);
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint q_srsrc;
vec4_uint k_srsrc;
q_srsrc[2] = __seqlen_q_stride;
if constexpr (Is_FlashMLA) {
k_srsrc[2] = seqlen_k_stride;
} else {
k_srsrc[2] = seqlen_v_stride;
}
q_srsrc[3] = 0;
k_srsrc[3] = 0;
int q_stage_id = 0;
int k_stage_id = 0;
if constexpr (STAGES == 2) {
q_stage_id ^= 1;
k_stage_id ^= 1;
}
{
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// k预取的标志位
int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
{
uint64_t q_base_addr;
int seqlen_q_stride;
int kloop_true;
if constexpr (Is_FlashMLA) {
q_srsrc[2] = __seqlen_q_stride;
q_base_addr = *(uint64_t*)&q_ptr;
seqlen_q_stride = __seqlen_q_stride;
kloop_true = k_loop;
} else {
q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
}
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
if (k_even) {
k_stage_id ^= 1;
int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop) / 2) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + block_table[index_topk/128] * batch_stride_k * ELEMENT_BYTES / 4 + (index_topk % 128) * seqlen_k_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
}
// 不对称MLS指令
if (k_even) {
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS + K_LOAD_REQUESTS);
} else {
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
}
q_stage_id ^= 1;
// Q DS
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
constexpr int k_loop = kHeadDim / kBlockK;
constexpr int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
if constexpr (k_even) {
k_stage_id ^= 1;
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// Q DS
q_stage_id ^= 1;
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
// K DS
k_stage_id ^= 1;
int stage_id = 0;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512_buffer_load<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr,block_table, batch_stride_v, n_loop, max_seq_k_offset);
#else
#endif
}
} // qk_gemm
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 2;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 准备 q,k 寄存器
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32)];
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int q_lds_base = reinterpret_cast<size_t>(q_lds);
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint q_srsrc;
vec4_uint k_srsrc;
q_srsrc[2] = __seqlen_q_stride;
if constexpr (Is_FlashMLA) {
k_srsrc[2] = seqlen_k_stride;
} else {
k_srsrc[2] = seqlen_v_stride;
}
q_srsrc[3] = 0;
k_srsrc[3] = 0;
int q_stage_id = 0;
int k_stage_id = 0;
if constexpr (STAGES == 2) {
q_stage_id ^= 1;
k_stage_id ^= 1;
}
{
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// k预取的标志位
int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
{
uint64_t q_base_addr;
int seqlen_q_stride;
int kloop_true;
if constexpr (Is_FlashMLA) {
q_srsrc[2] = __seqlen_q_stride;
q_base_addr = *(uint64_t*)&q_ptr;
seqlen_q_stride = __seqlen_q_stride;
kloop_true = k_loop;
} else {
q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
}
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
if (k_even) {
k_stage_id ^= 1;
int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = ((k_loop) / 2) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
}
// 不对称MLS指令
if (k_even) {
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS + K_LOAD_REQUESTS);
} else {
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
}
q_stage_id ^= 1;
// Q DS
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
constexpr int k_loop = kHeadDim / kBlockK;
constexpr int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
if constexpr (k_even) {
k_stage_id ^= 1;
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// Q DS
q_stage_id ^= 1;
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
// K DS
k_stage_id ^= 1;
int stage_id = 0;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr, batch_stride_v, n_loop, max_seq_k_offset);
#else
#endif
}
} // qk_gemm
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 32 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 1;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 准备 q,k 寄存器
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32)];
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int q_lds_base = reinterpret_cast<size_t>(q_lds);
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint q_srsrc;
vec4_uint k_srsrc;
q_srsrc[2] = __seqlen_q_stride;
if constexpr (Is_FlashMLA) {
k_srsrc[2] = seqlen_k_stride;
} else {
k_srsrc[2] = seqlen_v_stride;
}
q_srsrc[3] = 0;
k_srsrc[3] = 0;
int q_stage_id = 0;
int k_stage_id = 0;
if constexpr (STAGES == 2) {
q_stage_id ^= 1;
k_stage_id ^= 1;
}
{
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
{
uint64_t q_base_addr;
int seqlen_q_stride;
int kloop_true;
if constexpr (Is_FlashMLA) {
q_srsrc[2] = __seqlen_q_stride;
q_base_addr = *(uint64_t*)&q_ptr;
seqlen_q_stride = __seqlen_q_stride;
kloop_true = k_loop;
} else {
q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
}
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS + K_LOAD_REQUESTS);
q_stage_id ^= 1;
// Q DS
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// Q DS
q_stage_id ^= 1;
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
}
// K DS
k_stage_id ^= 1;
int stage_id = 0;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr, batch_stride_v, n_loop, max_seq_k_offset);
#else
#endif
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* v_faker,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 64) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 32 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 1;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint k_srsrc;
if constexpr (Is_FlashMLA) {
k_srsrc[2] = seqlen_k_stride;
} else {
k_srsrc[2] = seqlen_v_stride;
}
k_srsrc[3] = 0;
int k_stage_id = 0;
int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
int g_offset_s = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
if constexpr (STAGES == 2) {
k_stage_id ^= 1;
}
{
// int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
for(int k_loop = (kHeadDim / kBlockK) - 2; k_loop >= 0; --k_loop) {
{
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
g_offset_s = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[k_loop+1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[k_loop+1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[k_loop+1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[k_loop+1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// K DS
k_stage_id ^= 1;
int stage_id = 0;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// int abc[1];
// int index_topk = index_ptr[(((n_loop+1) % 16) * 64) + warp_id * 16];
// int offset_m = index_topk * seqlen_v_stride;
// auto g_abc = (reinterpret_cast<uint64_t>(v_faker + offset_m));
// inline_s_load_dword(abc[0], g_abc, 0);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
flash::lower_priority();
}
if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
// prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr, batch_stride_v, n_loop, max_seq_k_offset);
#else
#endif
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_666(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* v_faker,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 1;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint k_srsrc;
if constexpr (Is_FlashMLA) {
k_srsrc[2] = seqlen_k_stride;
} else {
k_srsrc[2] = seqlen_v_stride;
}
k_srsrc[3] = 0;
int k_stage_id = 0;
int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
int g_offset_s = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
int g_offset_s_2 = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
if constexpr (STAGES == 2) {
k_stage_id ^= 1;
}
{
// int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
// g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
for(int k_loop = (kHeadDim / kBlockK) - 2; k_loop >= 0; --k_loop) {
{
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
g_offset_s = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
g_offset_s_2 = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + 16;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// K DS
k_stage_id ^= 1;
int stage_id = 0;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// int abc[2];
// int index_topk = index_ptr[(n_loop * 64)];
// int index_topk2 = index_ptr[(n_loop * 64) + 8];
// int offset_m = index_topk * seqlen_k_stride;
// int offset_m2 = index_topk2 * seqlen_k_stride;
// auto g_abc = (reinterpret_cast<uint64_t>(v_faker + offset_m));
// auto g_abc2 = (reinterpret_cast<uint64_t>(v_faker + offset_m2));
// inline_s_load_dword(abc[0], g_abc, 0);
// inline_s_load_dword(abc[1], g_abc2, 0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[stage_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_777(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* v_faker,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = 64;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
#pragma unroll 1
for(int i=3;i>=0;i--)
{
int k_stage_id = 0;
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2;
int index_block = index_topk[i] / page_block_size;
int index_offset = index_topk[i] - index_block * page_block_size;
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16
+ (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
int g_offset_s = 512 * ELEMENT_BYTES / 4 + warp_id * 16;
int g_offset_s_2;
flash::wait_all_warp_arrived();
if(warp_id < 2){
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
}
if constexpr (STAGES == 2) {
k_stage_id ^= 1;
}
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[16].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[17].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 0 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 0 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[8].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[9].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[10].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[11].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[12].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[13].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[14].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[15].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_buffer_data_arrived<true>(0);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[2].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[3].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[4].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[5].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[6].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[7].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* v_faker,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 512) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
index_topk[0] = (index_topk[0] == -1) ? 0 : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? 0 : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? 0 : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? 0 : index_topk[3];
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int k_stage_id = 0;
int stage_id;
int index_block = index_topk[3] / page_block_size;
int index_offset = index_topk[3] - index_block * page_block_size;
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16
+ (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
int lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
int g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
int g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
k_stage_id ^= 1;
// #pragma unroll 1
for(int i=6;i>=0;i--)
{
int score_id = (i + 1) >> 1;
index_block = index_topk[i / 2] / page_block_size;
index_offset = index_topk[i / 2] - index_block * page_block_size;
g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16
+ (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 2].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 3].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 4].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 5].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 6].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[((i + 1) % 2) * 8 + 7].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[score_id / 2][score_id % 2].f32);
}
}
flash::lower_priority();
}
flash::wait_buffer_data_arrived<true>(0);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[2].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[3].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[4].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[5].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[6].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[7].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[0][0].f32);
}
}
flash::lower_priority();
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* v_faker,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
index_topk[0] = (index_topk[0] == -1) ? 0 : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? 0 : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? 0 : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? 0 : index_topk[3];
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
int k_stage_id = 0;
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2;
int index_block = index_topk[3] / page_block_size;
int index_offset = index_topk[3] - index_block * page_block_size;
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16
+ (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
int g_offset_s = 512 * ELEMENT_BYTES / 4 + warp_id * 16;
int g_offset_s_2;
flash::wait_all_warp_arrived();
if(warp_id<2)
{
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
}
if constexpr (STAGES == 2) {
k_stage_id ^= 1;
}
{
#pragma unroll 1
for(int i = 3;i >= 0;i--)
{
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[16].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[17].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 0 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 0 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[8].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[9].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[10].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[11].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[12].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[13].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[14].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[15].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
if(i != 0){
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
index_block = index_topk[i - 1] / page_block_size;
index_offset = index_topk[i - 1] - index_block * page_block_size;
g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16
+ (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
g_offset_s = 512 * ELEMENT_BYTES / 4 + warp_id * 16;
flash::wait_all_warp_arrived();
if(warp_id<2){
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
flash::wait_buffer_data_arrived<true>(1);
}
else{
flash::wait_buffer_data_arrived<true>(0);
}
}
else{
flash::wait_buffer_data_arrived<true>(0);
}
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[2].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[3].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[4].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[5].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[6].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[7].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
flash::lower_priority();
}
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_q(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 64) * (kHeadDim / kBlockK)],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int max_seq_q_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = 32;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 1;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
// 计算 q_lds,k_lds 的起始偏移量
int q_lds_base = reinterpret_cast<size_t>(q_lds);
int tid = threadIdx.x % 64;
// MLS
vec4_uint q_srsrc;
vec4_uint q_srsrc2;
q_srsrc[2] = __seqlen_q_stride;
q_srsrc2[2] = __seqlen_q_stride;
q_srsrc[3] = 0;
q_srsrc2[3] = 0;
int q_stage_id = 0;
if constexpr (STAGES == 2) {
q_stage_id ^= 1;
}
{
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
{
uint64_t q_base_addr;
int seqlen_q_stride;
int kloop_true;
if constexpr (Is_FlashMLA) {
q_srsrc[2] = __seqlen_q_stride;
q_base_addr = *(uint64_t*)&q_ptr;
seqlen_q_stride = __seqlen_q_stride;
kloop_true = k_loop;
} else {
q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
}
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
*(uint64_t*)&q_srsrc2 = VA_LIMIT_BITS(q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride + 32) * ELEMENT_BYTES);
// int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
// q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
int lds_offset2 = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc2, lds_offset2, 0);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
q_stage_id ^= 1;
// Q DS
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
int q_lds_load_offset2 = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[(k_loop-1) * 2 + 0].f16, true);
DS_READ_MATRIX_32X16_B16(q_lds_load_offset2, q_reg[(k_loop-1) * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(0);
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
// Q DS
q_stage_id ^= 1;
{
int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
int q_lds_load_offset2 = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[(kHeadDim / kBlockK - 1) * 2 + 0].f16, true);
DS_READ_MATRIX_32X16_B16(q_lds_load_offset2, q_reg[(kHeadDim / kBlockK - 1) * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(0);
}
}
#pragma once
#include "intrinsic_mls_ds.h"
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512(
vec4_uint q_ptr,
Element* q_lds,
int warp_id,
int seqlen_q_stride,
int max_seq_q_offset=0) {
// 编译期可知变量
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
// LDS 起始地址
int q_lds_base = reinterpret_cast<size_t>(q_lds);
// MLS
vec4_uint q_srsrc;
vec4_uint q_srsrc2;
q_srsrc[2] = seqlen_q_stride;
q_srsrc[3] = 0;
q_srsrc2[2] = seqlen_q_stride;
q_srsrc2[3] = 0;
int stage_id = 0;
{
int k_loop = 0;
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_ptr + (k_loop * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
*(uint64_t*)&q_srsrc2 = VA_LIMIT_BITS(*(uint64_t*)&q_ptr + (k_loop * kBlockK + warp_id * 16 * seqlen_q_stride + 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
// q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
// }
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
int lds_offset2 = (stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // pvgemm 完成后会发射q,k的预取,避免有的warp还没完成,即规避读V写Q/K,造成数据覆盖
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc2, lds_offset2, 0);
}
}
template<int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512(
vec4_uint k_ptr,
Element* k_lds,
int warp_id,
int seqlen_k_stride,
int max_seq_k_offset=0) {
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 2;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
int stage_id = 0;
int n_loop = 0;
int k_loop = 0;
// MLS
vec4_uint k_srsrc;
k_srsrc[2] = seqlen_k_stride;
if constexpr (true) {
int nm_filter = inline_min_max<0,16>(n_loop * WARP_N + 16 * warp_id_n + 16 - max_seq_k_offset);
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (n_loop * WARP_N * seqlen_k_stride + warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + k_loop * 32 * WARP_NUM_M) * ELEMENT_BYTES);
k_srsrc[3] = max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter << 8;
}
int lds_offset = (stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
template<int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512_buffer_load(
vec4_uint k_ptr,
Element* k_lds,
int warp_id,
int seqlen_k_stride,
int* index_ptr,
int* block_table,
int batch_stride,
int n_loop,
int max_seq_k_offset=0) {
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 2;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
int tid = threadIdx.x % 64;
int stage_id = 0;
int k_loop = 0;
int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + block_table[index_topk/128] * batch_stride * ELEMENT_BYTES / 4 + (index_topk % 128) * seqlen_k_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
}
template<int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512_buffer_load_nopage(
vec4_uint k_ptr,
Element* k_lds,
int warp_id,
int seqlen_k_stride,
int* index_ptr,
int batch_stride,
int n_loop,
int max_seq_k_offset=0) {
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int ELEMENT_BYTES = sizeof(Element);
constexpr int WARP_NUM_M = 2;
constexpr int WARP_NUM_N = 4;
int warp_id_m = warp_id / WARP_NUM_N;
int warp_id_n = warp_id % WARP_NUM_N;
int tid = threadIdx.x % 64;
int stage_id = 0;
int k_loop = 0;
int index_topk = index_ptr[(n_loop * 64) & 1023 + warp_id_n * 16 + (tid / 4)];
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
}
\ No newline at end of file
#pragma once
#include "philox.cuh"
#include "fwd/utils.h"
using namespace flash;
template<int THREADS, typename DataType=union_vec2_fp32>
struct PrefillMlaAllreduce {
static_assert(THREADS == 64);
template<typename Operator>
static __device__ inline DataType run(DataType x, Operator &op) {
DataType res;
if constexpr (std::is_same<DataType, union_vec2_fp32>::value) {
if constexpr (std::is_same<Operator, SumOp<float> >::value) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
res.f32[0] = __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 16);
res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
res.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 16);
#endif
}
else if constexpr (std::is_same<Operator, MaxOp<float> >::value) {
x.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 32));
x.f32[1] = op(x.f32[1], __shfl_xor_tmp(x.f32[1], 32));
res.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 16));
res.f32[1] = op(x.f32[1], __shfl_xor_tmp(x.f32[1], 16));
}
} else { // union_vec_fp32 f32
if constexpr (std::is_same<Operator, SumOp<float> >::value) {
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
res.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 16);
}
else if constexpr (std::is_same<Operator, MaxOp<float> >::value) {
x.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 32));
res.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 16));
}
}
return res;
}
};
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__device__ inline void prefill_mla_thread_reduce_max(const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary[m_idx * 2].f32[min_tile_m] = -INFINITY; // OpType:0 is sum operator, 1 is max operator
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2)
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
else
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx]);
}
}
}
}
}
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m];
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2)
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
else
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx]);
}
}
}
}
}
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
summary[m_idx * 2].u64 = 0x0;
} else {
summary[m_idx * 2].f32[0] = 0x0;
}
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
if constexpr (M_MMAC_COUNT == 2){
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
} else {
summary[m_idx * 2].f32[0] = summary[m_idx * 2].f32[0] + tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
}
#else
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary[m_idx * 2].f32[min_tile_m] = 0; // OpType:0 is sum operator, 1 is max operator
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2) {
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
} else {
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx]);
}
}
}
}
}
#endif
}
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
} else {
summary_cur[m_idx * 2].f32[0] = summary[m_idx * 2].f32[0];
}
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
} else {
summary_cur[m_idx * 2].f32[0] = summary_cur[m_idx * 2].f32[0] + tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
}
#else
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m];
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2) {
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
} else {
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 16)][min_tile_n].f32[vec_idx]);
}
}
}
}
}
#endif
}
}
}
template<typename Operator, typename DataType, int WARP_M, int M_MMAC_COUNT=2>
__device__ inline void prefill_mla_quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); mi++) {
dst[mi] = PrefillMlaAllreduce<64, DataType>::run(src[mi], op);
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__device__ inline void prefill_mla_reduce_(const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if constexpr (OpType == 0) { // sum
if constexpr (zero_init == true) {
prefill_mla_thread_reduce_sum<true, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op);
prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(summary, summary, op);
} else {
prefill_mla_thread_reduce_sum<false, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(summary_cur, summary_cur, op);
}
} else if constexpr (OpType == 1) { // max
if constexpr (zero_init == true) {
prefill_mla_thread_reduce_max<true, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op);
prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(summary, summary, op);
} else {
prefill_mla_thread_reduce_max<false, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(summary_cur, summary_cur, op);
}
}
}
// zero_init==true, max is current max_score, max_cur=nullptr
// zero_init==false, max is prev max_score, max_cur!=nullptr
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__device__ inline void reduce_max(const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], DataType1 *max , DataType1 *max_cur=nullptr) {
MaxOp<float> max_op;
if constexpr (zero_init == true) {
prefill_mla_reduce_<true, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, max, max_op);
} else {
prefill_mla_reduce_<false, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, max, max_op, max_cur);
}
}
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__device__ inline void reduce_sum(DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], DataType1 *sum, DataType1 *sum_cur=nullptr){
SumOp<float> sum_op;
if constexpr (zero_init == true) {
prefill_mla_reduce_<true, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, sum, sum_op);
} else {
prefill_mla_reduce_<false, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, sum, sum_op, sum_cur);
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], const DataType1 *max, const float scale) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const float max_scaled = (max[mi * 2].f32[min_tile_m] == -INFINITY) ? 0.f : (max[mi * 2].f32[min_tile_m] * (Scale_max ? scale : float(M_LOG2E)));
__float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
__float2 scale_pair = {scale, scale};
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_n * 2 + min_tile_m;
} else {
mmac_id = min_tile_n;
}
int qk_tile_id = mi + ni * (WARP_M / (16 * M_MMAC_COUNT));
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
}
#else
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx] * scale - max_scaled);
}
#endif
}
}
}
}
}
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1, int K/*head_dim_v*/, int kBlockK, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_M / (16 * M_MMAC_COUNT))][2 * M_MMAC_COUNT], DataType1 *scores_max, DataType1 *scores_sum,
DataType0 acc_o[(K / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT], float softmax_scale_log2) {
if constexpr (Is_first) {
reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max);
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, softmax_scale_log2);
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum);
} else {
DataType1 scores_max_cur[(WARP_M / (16 * M_MMAC_COUNT))];
reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
float scores_max_cur_reg = !Check_inf
? scores_max_cur[mi * 2].f32[min_tile_m]
: (scores_max_cur[mi * 2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi * 2].f32[min_tile_m]);
float scores_scale = __llvm_exp2_f32((scores_max[mi * 2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi * 2].f32[min_tile_m] *= scores_scale;
__float2 scores_scale_pair = {scores_scale, scores_scale};
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + mi + ni * (WARP_M / (16 * M_MMAC_COUNT));
int mmac_id;
if constexpr (M_MMAC_COUNT == 2) {
mmac_id = min_tile_n * 2 + min_tile_m;
} else {
mmac_id = min_tile_n;
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair
);
}
#else
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
acc_o[pv_tile_id][mmac_id].f32[vec_idx] *= scores_scale;
}
#endif
}
}
}
}
}
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max_cur, softmax_scale_log2);
DataType1 scores_sum_cur[(WARP_M / (16 * M_MMAC_COUNT))];
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
if constexpr (M_MMAC_COUNT == 2) {
scores_sum_cur[mi].u64 = 0x0;
} else {
scores_sum_cur[mi].f32[0] = 0x0;
}
}
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum_cur);
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
} else {
scores_sum[mi].f32[0] = scores_sum[mi].f32[0] + scores_sum_cur[mi].f32[0];
}
#else // for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results
if constexpr (M_MMAC_COUNT == 2) {
scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
} else {
scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
}
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__))
if constexpr (M_MMAC_COUNT == 2) {
inlineasm_fa_v_mov_b64(
scores_max[mi].u64,
scores_max_cur[mi].u64
);
} else {
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
}
#else
if constexpr (M_MMAC_COUNT == 2) {
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
} else {
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
}
#endif
}
}
};
// #define USE_CVT_PKRTZ_FP16_FP32
template <int WARP_M, int WARP_N, typename Element, typename ElementAccum, int M_MMAC_COUNT=2>
inline __device__ void prefill_mla_convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], union_vec4_fp32 s_reg[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT]) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32x2[min_tile_k]);
} else {
p_reg[n_idx * (WARP_M / 16) + m_idx][0].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 16) + m_idx][0].f32x2[min_tile_k]);
p_reg[n_idx * (WARP_M / 16) + m_idx][1].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 16) + m_idx][1].f32x2[min_tile_k]);
}
#else
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
#endif
}
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
inline __device__ void prefill_mla_apply_mask_gfx938(DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], const int max_seqlen_k,
const int col_idx_offset_ = 0) {
const int lane_id = threadIdx.x & 63;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
// if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
if constexpr (M_MMAC_COUNT == 2) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
} else {
tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = (col_idx >= max_seqlen_k)?-INFINITY:tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
// }
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
inline __device__ void decode_dsa_apply_mask_gfx938(DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], int* index_ptr,
const int col_idx_offset_ = 0,int real_topk = 512) {
const int lane_id = threadIdx.x & 63;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
// if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
if constexpr (M_MMAC_COUNT == 2) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
} else {
tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = ((col_idx >= real_topk) || (index_ptr[col_idx % 1024] == -1))?-INFINITY:tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
// }
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
inline __device__ void prefill_dsa_apply_mask_gfx938(DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], int* index_ptr,
const int col_idx_offset_ = 0,int real_topk = 512) {
const int lane_id = threadIdx.x & 63;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
// if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
if constexpr (M_MMAC_COUNT == 2) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
} else {
tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = ((col_idx >= real_topk) || (index_ptr[col_idx % 1024] == -1))?-INFINITY:tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
// }
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_mla_apply_mask_causal_gfx938(DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
const int row_idx_base = row_idx_offset + mi * (16 * M_MMAC_COUNT);
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q);
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
if constexpr (M_MMAC_COUNT == 2) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx];
} else {
tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT=2>
__forceinline__ __device__ void prefill_mla_apply_mtp_mask_causal_gfx938(DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15) + max_seqlen_k - max_seqlen_q;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
const int row_idx_base = row_idx_offset + mi * (16 * M_MMAC_COUNT);
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
if constexpr (M_MMAC_COUNT == 2) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > row_idx) ? -INFINITY: tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx];
} else {
tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = (col_idx > row_idx) ? -INFINITY: tensor[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
}
}
}
}
template<typename DataType, int kBlockN, int WARP_M, int WARP_NUM>
__forceinline__ __device__ void flashmla_apply_mtp_mask_causal_gfx938(
DataType s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q,
const int ngroups,
const int mtp) {
const int lane_id = threadIdx.x & 63;
constexpr int mi = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
int row_idx = row_idx_offset_ + (lane_id & 15);
int row_in_mtp = row_idx / ngroups;
int col_idx_limit_right = min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
#pragma unroll
for (int ni = 0; ni < kBlockN / 32; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_offset_ + ni * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx; /*BMZ vec_idx * 4 + (lane_id >> 4) */
s_reg[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: s_reg[mi + ni * (WARP_M / 16)][min_tile_n].f32[vec_idx];
}
}
}
}
}
template <bool HasWSLeft=true, typename DataType, int WARP_M, int WARP_N>
inline __device__ void prefill_mla_apply_mask_local_gfx938(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q,
const int window_size_left, const int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
const int col_idx_limit_left = std::max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right || (HasWSLeft && col_idx < (col_idx_limit_left - 1))) ?
-INFINITY: tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void prefill_mla_apply_alibi_gfx938(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, float g_alibi) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
#pragma unroll
for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] += g_alibi * (col_idx - row_idx);
}
}
}
}
}
}
#pragma once
#include "numeric_types.h"
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Interleave2, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_tp8_epilogue_store_output_gfx938(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
Params params,
int bidb,
int bidh,
int m_block,
int split_id,
int headdim_split_id,
int warp_id,
int lane_id) {
int o_row_stride = params.o_row_stride;
const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
SplitkvAccumType* o_ptr = Split
? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ split_id * params.b * params.o_batch_stride
: reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += 4) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
// which 32x32 tile
int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// index along seqlen_q dimension
int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx + min_tile_m * 16;
if (seqlen_q_idx < params.seqlen_q) {
if constexpr (Interleave2) {
/*contiguous 64 bytes storation*/
union_vec4_f16x2<SplitkvAccumType> v_data;
v_data.f16x2[0 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[0], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[0]);
v_data.f16x2[1 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[1], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[1]);
v_data.f16x2[0 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[2], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[2]);
v_data.f16x2[1 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[3], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[3]);
int pv_global_addr = seqlen_q_idx * o_row_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + pv_lane_head_dim_idx * 8;
*(vec4_fp32*)(o_ptr + pv_global_addr) = v_data.f32;
} else {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
union_vec2_f16x2<SplitkvAccumType> data;
int mmac_id = min_tile_m + min_tile_n * 2;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
data.f16x2[vec_index] = DownCastPair<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][mmac_id].f32x2[vec_index]);
}
int pv_global_addr = seqlen_q_idx * o_row_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + pv_lane_head_dim_idx * 4 + min_tile_n * 16;
*(union_vec2_f16x2<SplitkvAccumType>*)(o_ptr + pv_global_addr) = data;
}
}
}
}
}
}
}
}
\ No newline at end of file
#pragma once
#include "intrinsic_mls_ds.h"
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initialization(
vec4_uint q_addr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset,
vec2_Accum<ElementAccum> scores_max[WARP_M / 32],
vec2_Accum<ElementAccum> scores_sum[WARP_M / 32],
vec4_Accum<ElementAccum> acc_o[kHeadDimV / kBlockK][4]) {
flash::wait_all_warp_arrived();
// prepare mls buffer resource registers
vec4_uint q_srsrc;
q_srsrc[2] = query_seqlen_stride;
q_srsrc[3] = 0;
// total 16x576 f16s
// 16x128 f16s per wave first
constexpr int LOAD = 4;
constexpr int block32x16_bytes = 32 * 16 * sizeof(Element);
#pragma unroll
for (int load_id = 0; load_id < LOAD; ++load_id) {
// lds address
int lds_offset_bytes = (load_id * WARP_NUM + warp_id) * block32x16_bytes;
// global offset
int q_warp_offset = (load_id * WARP_NUM + warp_id) * 32;
// compute global address
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
// matrix load
__builtin_amdgcn_sched_barrier(0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
// insert valus in def-use
attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, kBlockK / 32, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
// fetch data from lds, from MID-th blocks
const int MID = 1;
#pragma unroll
for (int load_id = 0; load_id < MID; ++load_id) {
// wait global data written to lds
flash::wait_buffer_data_arrived<true/*sync*/>(LOAD - load_id - 1);
#pragma unroll
for (int i = 0; i < WARP_NUM; ++i) {
DS_READ_MATRIX_32X16_B16(load_id * WARP_NUM * block32x16_bytes + i * block32x16_bytes, q_reg[(load_id * 4 + i) * 2].f16, true);
}
}
// -------------------------------------------------------------------
// prefetch rest 16x64 loads
// 16x32 f16s 0-1 wave later
int lds_offset_bytes = (LOAD * WARP_NUM + warp_id) * block32x16_bytes;
int real_warp_id = warp_id >= 2 ? 0: warp_id;
int q_warp_offset = (LOAD * WARP_NUM + real_warp_id) * 32;
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
__builtin_amdgcn_sched_barrier(0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
// continue from MID
#pragma unroll
for (int load_id = MID; load_id < LOAD; ++load_id) {
// wait global data written to lds
flash::wait_buffer_data_arrived<true/*sync*/>(LOAD - load_id - 1 + MID);
#pragma unroll
for (int i = 0; i < WARP_NUM; ++i) {
DS_READ_MATRIX_32X16_B16(load_id * WARP_NUM * block32x16_bytes + i * block32x16_bytes, q_reg[(load_id * 4 + i) * 2].f16, true);
}
}
// wait global data written to lds
flash::wait_buffer_data_arrived<true/*sync*/>(0);
// write last data into registers
DS_READ_MATRIX_32X16_B16((LOAD * WARP_NUM + 0) * block32x16_bytes, q_reg[(16 + 0) * 2].f16, true);
DS_READ_MATRIX_32X16_B16((LOAD * WARP_NUM + 1) * block32x16_bytes, q_reg[(16 + 1) * 2].f16, true);
// wait all data written to registers
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
\ No newline at end of file
#include "numeric_types.h"
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_acco_reduce(
vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int warp_id,
int lane_id) {
constexpr int kBlockK = K_WARP_COUNT * 32;
// when REUSE_KV not in templated, compute max reuse times
int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2: ((seqlen_q + 1) / 2) * 2;
int HALF_REUSE_KV_TIMES = EVEN_REUSE_KV_TIMES >> 1;
int q_seq_idx = (lane_id & 15);
if (q_seq_idx < HALF_REUSE_KV_TIMES) { // 除以 2, 是因为每个线程都会储存两行的数据, seq 方向上是 0,0,1,1,2,2,3,3,4,4,....,15,15
for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
// ####################################################################################################################################################
// 4 个 wave 分别把自己负责的 acc_o 计算结果写到 LDS 中
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 一个 wave 共同持有 seqlen_q x kHeadDim 个 Half, 但为了节省 lds 用量, 每次只 reduce seqlen_q x kBlockK 个 Half
int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES + q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4/*0~3*/) * 4/*0~15*/;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32;
}
}
}
__syncthreads();
// ####################################################################################################################################################
// 在 lds 中求和, 把 4 个 wave 写的 acc_o 的数据加起来
// 如果恰好是 4 个 wave, 则 4 个 wave 一起参与到 lds 操作, 每个 wave 操作 4 个元素中的一个
if constexpr (WARP_NUM == 4) {
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + warp_id; // 之前是一次性写了 4 个 Half 到 lds, 现在 4 个 wave 分别处理这 4 个位置的 acc_o reduce
float acc_tmp_wave0 = acc_o_lds[lds_offset];
for (int loop = 1; loop < WARP_NUM; ++loop) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * kBlockK];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
}
}
}
}
// 不是恰好 4 个 wave, 则把 wave 0 单独拎出来做 lds reduce 操作
else if constexpr (WARP_NUM > 1) {
if (warp_id == 0) {
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
float acc_tmp_wave0 = acc_o_lds[lds_offset];
for (int loop = 1; loop < WARP_NUM; ++loop) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * kBlockK];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
}
}
}
}
}
}
__syncthreads();
// ####################################################################################################################################################
// 每个 wave 都从 LDS 获取最终的求和结果
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4;
acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
}
}
}
__syncthreads();
}
}
}
\ No newline at end of file
#include "numeric_types.h"
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, int Padding, typename ElementAccum>
__forceinline__ __device__ void mla_acco_reduce_tile16x32(
vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int warp_id,
int lane_id) {
constexpr int PREFETCH = WARP_NUM;
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int prefetch = 0; prefetch < PREFETCH; ++prefetch) {
vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
lds_write_offset = reinterpret_cast<size_t>(acc_o_lds + lds_write_offset + lane_id * 4);
inlineasm_ds_write_b128(lds_write_offset, f32x4);
}
}
union_vec4_fp32 data[2][WARP_NUM];
constexpr int ds_bursts = PREFETCH;
{
constexpr int min_tile_n = 0;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH);
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
{
constexpr int min_tile_n = 1;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH + ds_bursts);
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
{
constexpr int min_tile_n = 0;
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor + ds_bursts);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
}
{
constexpr int min_tile_n = 1;
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
}
flash::wait_all_warp_arrived();
}
}
\ No newline at end of file
#include "numeric_types.h"
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void mla_epilugue_rescale_acco(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT]) {
#pragma unroll
for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
__float2 scale_pair = {inv_sum, inv_sum};
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
}
#else
for (int vec_id = 0; vec_id < 4; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].f32[vec_id] *= inv_sum;
}
#endif
}
}
}
}
}
}
template<bool Split, bool Is_16x32, int M_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void mla_tp8_epilogue_store_softmax_lse(
vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
ElementAccum *softmax_lse_ptr,
ElementAccum scale_softmax,
int warp_id,
int thread_id,
int lane_id,
int headdim_split_id,
int seqlen_q_limit
) {
if constexpr (Split) {
bool write_ok = Is_16x32 ? (thread_id < 16 and headdim_split_id == 0): thread_id < 16;
if (write_ok) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row = Is_16x32
? mi * 32 + lane_id/*equal to lane_id & 15*/ + min_tile_m * 16
: warp_id * M_WARP_COUNT * 32 + mi * 32 + thread_id * 2 + min_tile_m;
if (row < seqlen_q_limit) {
softmax_lse_ptr[row] = scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(scores_sum[mi].f32[min_tile_m]);
}
}
}
}
}
}
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_epilogue_store_output(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
Params params,
int bidb,
int bidh,
int m_block,
int split_id,
int headdim_split_id,
int warp_id,
int lane_id) {
int output_seqlen_stride = params.o_row_stride;
const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
SplitkvAccumType* o_ptr = Split
? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ split_id * params.b * params.o_batch_stride
: reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false/*USE_CACHE_SWIZZLE*/>(o_ptr);
int pv_lane_seq_idx = (lane_id & 15);
int pv_lane_head_dim_idx = (lane_id >> 4);
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; ++k_loop) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
// 获取第几个 32x32 tile
int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// 当前 32x32 tile 的第几个 mmac
int mmac_id = min_tile_m + min_tile_n * 2;
// seqlen_q 方向上的坐标
int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx * 2 + min_tile_m;
if constexpr (WARP_NUM == 4) { // for 4 waves, storation can be done togather, performance 4%
int vec_index = warp_id;
int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
if (seqlen_q_idx < params.seqlen_q) {
o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
}
} else { // non-4-waves should use this, but lead to performance drop when 4 waves per SIMD
#pragma unroll
for (int vec_index = 0; vec_index < 4; ++vec_index) {
int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
if (seqlen_q_idx < params.seqlen_q) {
o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
}
}
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0)");
}
\ No newline at end of file
#include "numeric_types.h"
template<bool Split, int M_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void mla_epilogue_store_max_sum_tile16x32(
vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
ElementAccum *scores_max_ptr,
ElementAccum *scores_sum_ptr,
ElementAccum scale_softmax,
int warp_id,
int thread_id,
int lane_id,
int headdim_split_id,
int seqlen_q_limit
) {
#ifdef FA_DEBUG_SUM_MAX
constexpr bool ALLOW_WRITE_SUM_MAX = true;
#else
constexpr bool ALLOW_WRITE_SUM_MAX = false;
#endif
if constexpr (Split or ALLOW_WRITE_SUM_MAX) {
if (headdim_split_id == 0) { // 因为 split-D 使用同样的 QK, 计算得到同样的 scores_sum/scores_max 会写多遍, 可能会有数据冲突, 所以强制只写一遍
if (thread_id < 16) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row = /*warp_id * WARP_M + */mi * 32 + lane_id/*equal to lane_id & 15*/ + min_tile_m * 16;
if (row < seqlen_q_limit) {
scores_sum_ptr[row] = scores_sum[mi].f32[min_tile_m];
scores_max_ptr[row] = scores_max[mi].f32[min_tile_m] * scale_softmax;
}
}
}
}
}
}
}
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_epilogue_store_output_tile16x32(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
Params params,
int bidb,
int bidh,
int m_block,
int split_id,
int headdim_split_id,
int warp_id,
int lane_id) {
int output_seqlen_stride = params.o_row_stride;
const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
SplitkvAccumType* o_ptr = Split
? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ split_id * params.b * params.o_batch_stride
: reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
int pv_lane_seq_idx = (lane_id & 15);
int pv_lane_head_dim_idx = (lane_id >> 4);
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += WARP_NUM) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx + min_tile_m * 16;
if (seqlen_q_idx < params.seqlen_q) {
#pragma unroll
for (int vec_index = 0; vec_index < 4; ++vec_index) {
vec2_Element<SplitkvAccumType> data;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_m + min_tile_n * 2;
data[min_tile_n] = DownCast<ElementAccum, SplitkvAccumType, true>(acc_o[tile_32x32_id][mmac_id].f32[vec_index]);
}
int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2;
*(vec2_Element<SplitkvAccumType>*)(o_ptr + pv_global_addr) = data;
}
}
}
}
}
}
}
#include "numeric_types.h"
#include "intrinsic.h"
#include "wait.h"
#include "flash.h"
using namespace flash;
template<int WARP_M, int kHeadDimVSplit, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_initialize(
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16]
) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_max[m_idx] = -INFINITY;
scores_sum[m_idx] = 0.f;
}
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile) {
acc_o[m_idx][pv_tile].b64[0] = 0x0;
acc_o[m_idx][pv_tile].b64[1] = 0x0;
}
}
}
template<int kBlockM, int WARP_M, int WARP_NUM, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_fetch_q_to_vgpr(
union_vec4_f16x2<Element> qv_regs[WARP_M / 16][8],
union_vec4_f16x2<Element> q_regs[WARP_M / 16],
Element* qv_ptr,
Element* q_ptr,
int m_block,
int warp_id_row,
int warp_id_col,
int lane_id,
int qv_row_stride,
int q_row_stride,
int actual_seqlen_q
) {
constexpr bool IS_8_WAVES = WARP_NUM == 8;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int load_loop = 0; load_loop < 8; ++load_loop) {
int qv_row = min(actual_seqlen_q - 1 - m_block * kBlockM, m_idx * (IS_8_WAVES ? 64: WARP_M) + warp_id_row * 16 + (lane_id & 15));
int qv_col = (lane_id >> 4) * 8 + warp_id_col * 32 + load_loop * 64;
int qv_buffer_offset = qv_row * qv_row_stride + qv_col;
qv_regs[m_idx][load_loop] = *(union_vec4_f16x2<Element>*)(qv_ptr + qv_buffer_offset);
}
}
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
int q_row = min(actual_seqlen_q - 1 - m_block * kBlockM, m_idx * (IS_8_WAVES ? 64: WARP_M) + warp_id_row * 16 + (lane_id & 15));
int q_col = (lane_id >> 4) * 8 + warp_id_col * 32;
int q_buffer_offset = q_row * q_row_stride + q_col;
q_regs[m_idx] = *(union_vec4_f16x2<Element>*)(q_ptr + q_buffer_offset);
}
}
template<int kBlockN, int WARP_NUM, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_k_rope_to_lds(
Element* k_rope_lds,
vec4_uint k_buffer,
int warp_id,
int lane_id,
int k_row_stride,
int seqlen_kv_limit
) {
if constexpr (WARP_NUM == 8) {
int warp_id_row = warp_id & 3;
int warp_id_col = warp_id >> 2;
#pragma unroll
for (int load_loop = 0; load_loop < 2; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
int k_col = warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * k_row_stride + k_col;
int lds_write_offset = load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32; // 8 * 4 * 16 * 32 * sizeof(fp16) = 32KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(k_rope_lds, k_buffer, lds_write_offset, 0, k_buffer_offset);
}
} else if constexpr (WARP_NUM == 4) {
constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 2);
int warp_id_row = warp_id >> 1;
int warp_id_col = warp_id & 1;
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 32 + warp_id_row * 16 + (lane_id >> 2));
int k_col = warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * k_row_stride + k_col;
int lds_write_offset = load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32; // 4 * 4 * 16 * 32 * sizeof(fp16) = 16KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(k_rope_lds, k_buffer, lds_write_offset, 0, k_buffer_offset);
}
}
}
template<int kBlockN, int WARP_M, int WARP_N, int WARP_NUM, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_qk_rope(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
union_vec4_f16x2<Element> q_regs[WARP_M / 16],
vec4_uint k_buffer,
Element* k_rope_lds,
int warp_id,
int lane_id,
int k_row_stride,
int seqlen_kv_limit) {
if constexpr (WARP_NUM == 8) {
// mla_prefetch_k_rope_to_lds<kBlockN, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
wait_buffer_data_arrived<true>(0);
int warp_id_col = warp_id >> 2;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = (n_loop >> 2) * 8 * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(k_rope_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false/*sync*/>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(q_regs[m_idx].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(q_regs[m_idx].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__syncthreads();
} else if constexpr (WARP_NUM == 4) {
// mla_prefetch_k_rope_to_lds<kBlockN, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
wait_buffer_data_arrived<true>(0);
int warp_id_col = warp_id & 1;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = n_loop * 2 * 16 * 32 + warp_id_col * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(k_rope_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false/*sync*/>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(q_regs[m_idx].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(q_regs[m_idx].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__syncthreads();
}
}
template<int kBlockN, int WARP_NUM, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_k_nope_to_lds(
Element* v_lds,
vec4_uint v_buffer,
int warp_id,
int lane_id,
int v_row_stride,
int seqlen_kv_limit
) {
if constexpr (WARP_NUM == 8) {
constexpr int PREFETCH_K_BLOCKS = 2;
constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 4); // 16 * 4 = 64
int warp_id_row = warp_id & 3;
int warp_id_col = warp_id >> 2;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH_K_BLOCKS; load_id += 2) {
#pragma unroll
for (int depth = 0; depth < 2; ++depth) {
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
int k_col = (load_id + depth) * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * v_row_stride + k_col;
int lds_write_offset = depth * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32; // 2 * 2 * 8 * 16 * 32 * sizeof(fp16) = 32KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
}
}
}
} else if constexpr (WARP_NUM == 4) {
__syncthreads();
constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 2);
int warp_id_row = warp_id >> 1;
int warp_id_col = warp_id & 1;
int stage_id = 0;
constexpr int load_id = 0;
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 32 + warp_id_row * 16 + (lane_id >> 2));
int k_col = load_id * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * v_row_stride + k_col;
int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32; // 4 * 4 * 16 * 32 * sizeof(fp16) = 16KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
}
}
}
template<int kBlockN, int WARP_M, int WARP_N, int WARP_NUM, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_qk_nope(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
union_vec4_f16x2<Element> qv_regs[WARP_M / 16][8],
vec4_uint v_buffer,
Element* v_lds,
vec4_uint k_buffer,
Element* k_rope_lds,
int warp_id,
int lane_id,
int v_row_stride,
int k_row_stride,
int seqlen_kv_limit) {
if constexpr (WARP_NUM == 8) {
constexpr int PREFETCH_K_BLOCKS = 2;
constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 4);
int warp_id_row = warp_id & 3;
int warp_id_col = warp_id >> 2;
// prefetch_k_nope_to_lds<Element>(v_lds, v_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
inline_vgpr4_init_zero(s_reg[m_idx][n_loop]);
}
}
#pragma unroll
for (int load_id = 0; load_id < PREFETCH_K_BLOCKS; load_id += 2) {
#pragma unroll
for (int depth = 0; depth < 2; ++depth) {
wait_buffer_data_arrived<true>((2 - depth - 1) * K_LOAD_REQUESTS);
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = depth * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + (n_loop >> 2) * WARP_NUM * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id + depth].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id + depth].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
}
}
asm volatile("s_barrier\n"); // 上面在读 lds, 下面在写 lds, 有数据冲突的隐患
// 提前预取 k_rope 部分的数据, 注意 lds 部分重叠
mla_prefix_prefill_prefetch_k_rope_to_lds<kBlockN, WARP_NUM, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
// 接着做剩下的内容
if constexpr (true) {
int stage_id = 0;
{
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
int k_col = PREFETCH_K_BLOCKS * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * v_row_stride + k_col;
int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
}
}
stage_id ^= 1;
#pragma unroll
for (int load_id = PREFETCH_K_BLOCKS + 1; load_id < 8; load_id += 1) {
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
int k_col = load_id * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * v_row_stride + k_col;
int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
}
wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
stage_id ^= 1;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + (n_loop >> 2) * WARP_NUM * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
// rest
{
constexpr int load_id = 8;
wait_buffer_data_arrived<true>(0);
stage_id ^= 1;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = stage_id * K_LOAD_REQUESTS * 8 * 16 * 32 + (n_loop >> 2) * 8 * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
} else if constexpr (WARP_NUM == 4) {
constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 2);
int warp_id_row = warp_id >> 1;
int warp_id_col = warp_id & 1;
int stage_id = 0;
// mla_prefetch_k_nope_to_lds<kBlockN, Element>(v_lds, v_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
inline_vgpr4_init_zero(s_reg[m_idx][n_loop]);
}
}
stage_id ^= 1;
#pragma unroll
for (int load_id = 1; load_id < 8; ++load_id) {
#pragma unroll
for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop) {
int k_row = min(seqlen_kv_limit - 1, load_loop * 32 + warp_id_row * 16 + (lane_id >> 2));
int k_col = load_id * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
int k_buffer_offset = k_row * v_row_stride + k_col;
int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32; // 4 * 4 * 16 * 32 * sizeof(fp16) = 16KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
}
wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
stage_id ^= 1;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + n_loop * 2 * 16 * 32 + warp_id_col * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__syncthreads();
}
// 预取 rope 部分的 K 数据, 注意 k_rope_lds 和 k_lds 的重叠关系
mla_prefix_prefill_prefetch_k_rope_to_lds<kBlockN, WARP_NUM, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
{
int load_id = 8;
wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
stage_id ^= 1;
union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + n_loop * 2 * 16 * 32 + warp_id_col * 16 * 32;
int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
}
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
// 准备做 mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
}
}
__syncthreads();
}
}
}
template<int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)], ElementAccum* s_reg_lds, int warp_id, int lane_id) {
constexpr bool IS_8_WAVES = WARP_NUM == 8;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int lds_write_offset = n_loop * WARP_NUM * (64 * 4) + warp_id * 64 * 4 + lane_id * 4;
*(vec4_fp32*)(s_reg_lds + lds_write_offset) = s_reg[m_idx][n_loop].f32;
}
__syncthreads();
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
int warp_id_symmetry = IS_8_WAVES
? ((warp_id >= 4) ? warp_id - 4: warp_id + 4)
: ((warp_id & 1) ? warp_id - 1: warp_id + 1);
int lds_load_offset = n_loop * WARP_NUM * (64 * 4) + warp_id_symmetry * 64 * 4 + lane_id * 4;
vec4_Accum<ElementAccum> symmetry_data = *(vec4_Accum<ElementAccum>*)(s_reg_lds + lds_load_offset);
s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[0], symmetry_data.u64[0]);
s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[1], symmetry_data.u64[1]);
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
template<int kBlockN, int WARP_M, int kHeadDimVSplit, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
ElementAccum scale_softmax_log2,
vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16]) {
ElementAccum scores_max_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_max_cur[m_idx] = scores_max[m_idx];
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
scores_max_cur[m_idx] = max(scores_max_cur[m_idx], s_reg[m_idx][n_loop].f32[vec_idx]);
}
}
scores_max_cur[m_idx] = max(scores_max_cur[m_idx], __shfl_xor_tmp(scores_max_cur[m_idx], 32));
scores_max_cur[m_idx] = max(scores_max_cur[m_idx], __shfl_xor_tmp(scores_max_cur[m_idx], 16));
}
// 做 softmax
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
__float2 max_scaled;
max_scaled[0] = scores_max_cur[m_idx] == -INFINITY ? 0.f: -scores_max_cur[m_idx] * scale_softmax_log2;
max_scaled[1] = max_scaled[0];
__float2 scale_softmax_log2_pair;
scale_softmax_log2_pair[0] = scale_softmax_log2;
scale_softmax_log2_pair[1] = scale_softmax_log2;
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].f32[0] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[0]);
s_reg[m_idx][n_loop].f32[1] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[1]);
s_reg[m_idx][n_loop].f32[2] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[2]);
s_reg[m_idx][n_loop].f32[3] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[3]);
}
}
// 求和
ElementAccum scores_sum_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
__float2 scores_sum_pair;
scores_sum_pair[0] = 0;
scores_sum_pair[1] = 0;
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]);
scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]);
}
scores_sum_cur[m_idx] = scores_sum_pair[0] + scores_sum_pair[1];
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 32);
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 16);
}
// 放缩
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
__float2 scores_scale;
scores_scale[0] = __llvm_exp2_f32(__llvm_fma_f32(scores_max[m_idx], scale_softmax_log2, /*max_scaled[0]*/-scores_max_cur[m_idx] * scale_softmax_log2));
scores_scale[1] = scores_scale[0];
scores_sum[m_idx] *= scores_scale[0];
#pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale);
acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale);
}
}
// update max/sum
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_sum[m_idx] += scores_sum_cur[m_idx];
scores_max[m_idx] = scores_max_cur[m_idx];
}
}
template<int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_mask(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
int lane_id,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q) {
constexpr bool IS_8_WAVES = WARP_NUM == 8;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
int col_idx_limit_right = max_seqlen_k;
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_offset_ + n_loop * 16 + vec_idx * 4 + (lane_id >> 4);
s_reg[m_idx][n_loop].f32[vec_idx] = (col_idx >= col_idx_limit_right) ? -INFINITY: s_reg[m_idx][n_loop].f32[vec_idx];
}
}
}
}
template<int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_mtp_mask(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
int lane_id,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q,
const int ngroups,
const int mtp) {
constexpr bool IS_8_WAVES = WARP_NUM == 8;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
int row_idx = row_idx_offset_ + (IS_8_WAVES ? m_idx * 64: m_idx * WARP_M) + (lane_id & 15);
int row_in_mtp = row_idx / ngroups;
int col_idx_limit_right = min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_offset_ + n_loop * 16 + vec_idx * 4 + (lane_id >> 4);
s_reg[m_idx][n_loop].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: s_reg[m_idx][n_loop].f32[vec_idx];
}
}
}
}
template<int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_causal_mask(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
int lane_id,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q) {
constexpr bool IS_8_WAVES = WARP_NUM == 8;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
int row_idx = row_idx_offset_ + (IS_8_WAVES ? m_idx * 64: m_idx * WARP_M) + (lane_id & 15);
int col_idx_limit_right = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q);
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_offset_ + n_loop * 16 + vec_idx * 4 + (lane_id >> 4);
s_reg[m_idx][n_loop].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: s_reg[m_idx][n_loop].f32[vec_idx];
}
}
}
}
template<int kBlockN, int WARP_M, typename ElementAccum, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_cvt_dtype(
vec4_Accum<ElementAccum> s_reg[WARP_M / 16][kBlockN / 16],
union_vec2_f16x2<Element> p_reg[WARP_M / 16][kBlockN / 16]) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
p_reg[m_idx][n_loop].f16x2[vec_idx] = DownCastPair<ElementAccum, Element>(s_reg[m_idx][n_loop].f32x2[vec_idx]);
}
#else
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
p_reg[m_idx][n_loop].f16[vec_idx] = DownCast<ElementAccum, Element, false>(s_reg[m_idx][n_loop].f32[vec_idx]);
}
#endif
}
}
}
template<int PREFETCH_V_BLOCKS, int WARP_NUM, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_v_to_lds(
vec4_uint v_buffer,
Element* v_lds,
int v_row_stride,
int warp_id,
int lane_id,
int seqlen_kv_limit
) {
if constexpr (WARP_NUM == 8) {
#pragma unroll
for (int n_loop = 0; n_loop < PREFETCH_V_BLOCKS; n_loop += 2) {
#pragma unroll
for (int depth = 0; depth < 2; ++depth) {
#pragma unroll
for (int load_loop = 0; load_loop < 2; ++load_loop) {
int v_row = min(seqlen_kv_limit - 1, (n_loop + depth) * 16 + (lane_id >> 2));
int v_col = load_loop * 8 * 32 + warp_id * 32 + (lane_id & 3) * 8;
int v_buffer_offset = v_row * v_row_stride + v_col;
int lds_write_offset = depth * 2 * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512; // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
}
}
}
} else if constexpr (WARP_NUM == 4) {
__syncthreads();
constexpr int V_LOAD_REQUESTS = PREFETCH_V_BLOCKS; // union
int warp_id_col = warp_id & 1;
int stage_id = 0;
constexpr int n_loop = 0;
#pragma unroll
for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop) {
int v_row = min(seqlen_kv_limit - 1, n_loop * 16 + (lane_id >> 2));
int v_col = load_loop * 4 * 32 + warp_id * 32 + (lane_id & 3) * 8;
int v_buffer_offset = v_row * v_row_stride + v_col;
int lds_write_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512; // 4 * 4 * 512 * sizeof(fp16) = 16KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
}
}
}
template<bool PREFETCH_K, int PREFETCH_V_BLOCKS, int kBlockN, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_pv(
vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
union_vec2_f16x2<Element> p_reg[WARP_M / 16][kBlockN / 16],
vec4_uint v_buffer,
Element* v_lds,
int warp_id,
int lane_id,
int v_row_stride,
int seqlen_kv_limit,
int v_buffer_offset) {
if constexpr (WARP_NUM == 8) {
wait_buffer_data_arrived<true>(0);
constexpr int V_LOAD_REQUESTS = 2;
int warp_id_col = warp_id >> 2;
#pragma unroll
for (int n_loop = 0; n_loop < PREFETCH_V_BLOCKS; n_loop += 2) {
#pragma unroll
for (int depth = 0; depth < 2; ++depth) {
// lds -> vgprs
union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
int v_load_base_offset = depth * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
#pragma unroll
for (int i = 0; i < 2; ++i) {
int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
}
}
// pv mmac
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
// v interleave into vgprs
union_vec4_f16x2<Element> v_composed;
v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
// pv mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop + depth].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop + depth].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
}
}
}
}
asm volatile("s_barrier\n"); // 上面在读 lds, 下面在写 lds, 有数据冲突的隐患
// 做没预取的部分, 还需要重新取数据
if constexpr (true) {
int stage_id = 0;
{
#pragma unroll
for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop) {
int v_row = min(seqlen_kv_limit - 1, PREFETCH_V_BLOCKS * 16 + (lane_id >> 2));
int v_col = load_loop * 8 * 32 + warp_id * 32 + (lane_id & 3) * 8;
int v_buffer_offset = v_row * v_row_stride + v_col;
int lds_write_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512; // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
}
}
stage_id ^= 1;
#pragma unroll
for (int n_loop = PREFETCH_V_BLOCKS + 1; n_loop < kBlockN / 16; n_loop += 1) {
#pragma unroll
for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop) {
int v_row = min(seqlen_kv_limit - 1, n_loop * 16 + (lane_id >> 2));
int v_col = load_loop * 8 * 32 + warp_id * 32 + (lane_id & 3) * 8;
int v_buffer_offset = v_row * v_row_stride + v_col;
int lds_write_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512; // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
}
stage_id ^= 1;
wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
// lds -> vgprs
union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
int v_load_base_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
#pragma unroll
for (int i = 0; i < 2; ++i) {
int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
}
}
// pv mmac
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
// v interleave into vgprs
union_vec4_f16x2<Element> v_composed;
v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
// pv mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
// rest
{
constexpr int n_loop = kBlockN / 16;
stage_id ^= 1;
wait_buffer_data_arrived<true>(0);
// lds -> vgprs
union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
int v_load_base_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
#pragma unroll
for (int i = 0; i < 2; ++i) {
int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
}
}
// pv mmac
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
// v interleave into vgprs
union_vec4_f16x2<Element> v_composed;
v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
// pv mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
} else if constexpr (WARP_NUM == 4) {
constexpr int V_LOAD_REQUESTS = 4;
// mla_prefetch_v_to_lds<V_LOAD_REQUESTS, Element>(v_buffer, v_lds, v_row_stride, warp_id, lane_id, seqlen_kv_limit);
int stage_id = 1;
int warp_id_col = warp_id & 1;
#pragma unroll
for (int n_loop = 1; n_loop < kBlockN / 16; ++n_loop) {
#pragma unroll
for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop) {
int v_row = min(seqlen_kv_limit - 1, n_loop * 16 + (lane_id >> 2));
int v_col = load_loop * 4 * 32 + warp_id * 32 + (lane_id & 3) * 8;
int v_buffer_offset = v_row * v_row_stride + v_col;
int lds_write_offset = stage_id * 4 * 4 * 512 + load_loop * 4 * 512 + warp_id * 512; // 4 * 4 * 512 * sizeof(fp16) = 16KB
safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
}
wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
stage_id ^= 1;
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
// lds -> vgprs
union_vec4_f16x2<Element> v_regs;
int v_load_base_offset = stage_id * 4 * 4 * 512 + warp_id_col * 8 * 512 + v_tile * 512;
#pragma unroll
for (int i = 0; i < 4; ++i) {
int v_load_offset = v_load_base_offset + i * 4 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
v_regs.f16x2[i] = *(vec2_Element<Element>*)(v_lds + v_load_offset);
}
// v regs interleave
union_vec4_f16x2<Element> v_composed;
v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 0], v_regs.f16[1 * 2 + 0], v_regs.f16[2 * 2 + 0], v_regs.f16[3 * 2 + 0]);
v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 1], v_regs.f16[1 * 2 + 1], v_regs.f16[2 * 2 + 1], v_regs.f16[3 * 2 + 1]);
// pv mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
}
}
__syncthreads();
}
{
if constexpr (PREFETCH_K) {
vec4_uint k_rope_buffer = v_buffer;
*(int64_t*)&k_rope_buffer += v_buffer_offset;
mla_prefix_prefill_prefetch_k_nope_to_lds<kBlockN, WARP_NUM, Element>(v_lds, k_rope_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit - kBlockN);
wait_buffer_data_arrived<true>(kBlockN / (16 * 2));
} else {
wait_buffer_data_arrived<true>(0);
}
constexpr int n_loop = kBlockN / 16;
stage_id ^= 1;
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
// lds -> vgprs
union_vec4_f16x2<Element> v_regs;
int v_load_base_offset = stage_id * 4 * 4 * 512 + warp_id_col * 8 * 512 + v_tile * 512;
#pragma unroll
for (int i = 0; i < 4; ++i) {
int v_load_offset = v_load_base_offset + i * 4 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
v_regs.f16x2[i] = *(vec2_Element<Element>*)(v_lds + v_load_offset);
}
// v vgpr interleave
union_vec4_f16x2<Element> v_composed;
v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 0], v_regs.f16[1 * 2 + 0], v_regs.f16[2 * 2 + 0], v_regs.f16[3 * 2 + 0]);
v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 1], v_regs.f16[1 * 2 + 1], v_regs.f16[2 * 2 + 1], v_regs.f16[3 * 2 + 1]);
// pv mmac
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
}
}
__syncthreads();
}
}
}
template<int kBlockM, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o(
vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
ElementAccum* scores_max_ptr,
ElementAccum* scores_sum_ptr,
ElementAccum* softmax_lse_ptr,
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
ElementAccum scale_softmax,
int64_t row_offset_lse,
int m_block,
int warp_id,
int warp_id_row,
int lane_id,
int actual_seqlen_q
) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
ElementAccum sum = scores_sum[m_idx];
ElementAccum lse = (sum == 0.f || sum != sum) ? INFINITY: __llvm_fma_f32(scores_max[m_idx], scale_softmax, __logf(sum));
if constexpr (WARP_NUM == 8) {
if (lane_id < 16 and warp_id < 4) {
int lse_offset = warp_id * 16 + lane_id;
if (lse_offset < actual_seqlen_q - m_block * kBlockM) {
scores_max_ptr[row_offset_lse + lse_offset] = scores_max[m_idx] * scale_softmax;
scores_sum_ptr[row_offset_lse + lse_offset] = scores_sum[m_idx];
softmax_lse_ptr[row_offset_lse + lse_offset] = lse;
}
}
} else if constexpr (WARP_NUM == 4) {
if (lane_id < 16 and ((warp_id & 1) == 0)) {
int lse_offset = m_idx * WARP_M + warp_id_row * 16 + lane_id;
if (lse_offset < actual_seqlen_q - m_block * kBlockM) {
scores_max_ptr[row_offset_lse + lse_offset] = scores_max[m_idx] * scale_softmax;
scores_sum_ptr[row_offset_lse + lse_offset] = scores_sum[m_idx];
softmax_lse_ptr[row_offset_lse + lse_offset] = lse;
}
}
}
// 放缩 acc_o
__float2 inv_sum;
inv_sum[0] = 1.0f / sum;
inv_sum[1] = inv_sum[0];
#pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum);
acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum);
}
}
}
template<int kBlockM, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_store_output(
vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
void* __restrict__ o_raw_ptr,
int64_t row_offset_o,
int m_block,
int warp_id_row,
int warp_id_col,
int lane_id,
int o_row_stride,
int actual_seqlen_q
) {
Element *o_ptr = reinterpret_cast<Element*>(o_raw_ptr) + row_offset_o;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile) {
int row_idx = (WARP_NUM == 8 ? m_idx * 64: m_idx * WARP_M) + warp_id_row * 16 + (lane_id & 15);
if (m_block * kBlockM + row_idx < actual_seqlen_q) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
vec2_Element<Element> data;
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
data[mmac_id] = DownCast<ElementAccum, Element, true>(acc_o[m_idx][v_tile * 2 + mmac_id].f32[vec_idx]);
}
#else
data = DownCastPairNoPack<ElementAccum, Element>(acc_o[m_idx][v_tile * 2 + 0].f32[vec_idx], acc_o[m_idx][v_tile * 2 + 1].f32[vec_idx]);
#endif
int col_idx = warp_id_col * 256 + v_tile * 32 + vec_idx * 8 + (lane_id >> 4) * 2;
int64_t write_offset = row_idx * int64_t(o_row_stride) + col_idx;
*(vec2_Element<Element>*)(o_ptr + write_offset) = data;
}
}
}
}
}
\ No newline at end of file
#include "mla_pv_gemm_utils.h"
template<int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_pv_gemm_prefetch_k(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=-1) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert(kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in mla_pv_gemm_prefetch_k must be WARP_N * 32");
union_vec2_f16x2<Element> v_reg[STAGES * PV_K_WARP_COUNT * PV_N_WARP_COUNT][4];
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 按照上面的读取方式, 第二个 dword 偏移 32 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // 一个 warp 每次读几行数据, loadx4, 每个线程读取 8 个 Half, 每行 32 个 Half 需要 32 / 8 = 4 个线程, 所以一个 wave 64 线程会读取 16 行
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 每次 load 多少数据, 16x32
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 一个 warp 一共要发几次读取请求
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 一个 warp 一共要发几次读取请求
constexpr int READ_ELEMENT_COUNT = 8; // 每个线程一次读取几个 Half
int v_lane_headdim_n_idx = lane_id & 3; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_2 & 0xc); // 第几个 4 线程组的最小id
int tail = (laneid_shfl_2 & 0x3); // 4 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// global -> lds, seqlen 方向的坐标
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + warp_id * STAGES * WARP_K * kBlockN;
int stage_id = (STAGES == 2) ? 1: 0;
constexpr int N_LOOP_START = (STAGES == 2) ? 1: 0;
for (int n_loop = N_LOOP_START; n_loop < K_LOOP_COUNT; ++n_loop) {
int v_block_buffer_load_global_offset = warp_id * WARP_K * kvcache_seqlen_stride + n_loop * kBlockN;
for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_k_id = (load + warp_id) % V_LOAD_REQUESTS;
int v_warp_buffer_load_lds_offset = load * READ_ONCE_COUNT;
int v_gvoffset_s = v_block_buffer_load_global_offset / 2;
int v_gvoffset_v = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * kvcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + stage_id * WARP_K * kBlockN, v_addr, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
if constexpr (STAGES == 2) stage_id ^= 1;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4/*4 bytes per dword*/;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0/*vec_idx*/],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2/*vec_idx*/].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
if constexpr (STAGES == 2) {
int n_loop = K_LOOP_COUNT - 1;
stage_id ^= 1;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<0>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0/*vec_idx*/],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2/*vec_idx*/].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0/*vec_idx*/],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2/*vec_idx*/].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment