Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
......@@ -36,7 +36,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_3stage(
constexpr int ELEMENT_BYTES = sizeof(Element);
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg);
} else {
......
......@@ -28,7 +28,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_tile16x32(
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
#if defined(__gfx936__) || defined(__gfx938__) // >= bmz
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
......
......@@ -218,7 +218,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
if(zero_init == true) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
summary[m_idx * 2].u64 = 0x0;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
......@@ -227,7 +227,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -254,7 +254,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
......@@ -262,7 +262,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -362,16 +362,15 @@ inline __device__ void kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M / 32) *
// min tile is 32 * 32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32(
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
......@@ -448,10 +447,10 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
......@@ -503,8 +502,8 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
#pragma unroll
for(int warp_loop = 1; warp_loop < WARP_NUM; warp_loop++) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
......@@ -528,8 +527,8 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
}
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......@@ -558,7 +557,7 @@ inline __device__ void kvcache_convert_pk_type(union_vec2_f16x2<Element> p_reg[(
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) || defined(__gfx92a__)
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
......
#pragma once
#include "f16_mla_tp8_pv_gemm_utils_gfx938.h"
template<int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void f16_mla_tp8_pv_gemm_gfx938(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=0) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32");
static_assert (M_WARP_COUNT == 1, "for gfx938, only WARP_M = 32 is supported yet!");
static_assert (PV_N_WARP_COUNT == 1, "for gfx938, only WARP_N = 32 is supported yet!");
static_assert (PV_K_WARP_COUNT == 1, "for gfx938, only WARP_K = 32 is supported yet!");
constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
// 准备寄存器, 每次加载 32x32 的 half 用于 mmac 计算, 每个线程持有 16 个 half, 因此是 8 * 2, 一列有 8 个 half, 有两列
union_vec4_f16x2<Element> v_reg[1 * PV_K_WARP_COUNT * PV_N_WARP_COUNT * 2];
// 准备 MLS 的 resource 寄存器
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = kvcache_seqlen_stride; // stride
// 防止与多 wave reduce max 需要的 lds 冲突
__syncthreads();
int stage_id = (STAGES == 2) ? 1: 0;
// 一次加载多批数据
constexpr int N_LOOP_STEP = (STAGES == 2) ? 2: 1;
constexpr int N_LOOP_START = (STAGES == 2) ? K_LOOP_COUNT - N_LOOP_STEP * 2: K_LOOP_COUNT - 1;
constexpr int N_LOOP_END = 0;
for (int n_loop = N_LOOP_START; n_loop >= N_LOOP_END; n_loop -= N_LOOP_STEP) {
#pragma unroll
for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id) {
// 计算当前 wave 当前加载的 32x32 block 的偏移字节数
int v_mls_warp_global_offset = (n_loop + prefetch_id) * kBlockN * sizeof(Element);
// 计算当前 wave 写入 lds 的偏移地址(注意 v_lds 相较于 smem 的偏移量)
int v_mls_lds_warp_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32) * sizeof(Element);
// 计算当前 wave 读取数据的起始偏移字节数
int v_mls_loop_global_offset; // = warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
// 计算 MLS 读取数据的 global 地址, 判断边界
if constexpr (true) {
int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
v_mls_loop_global_offset = real_mls_warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
v_srsrc[3] = max_seq_kv_offset % kBlockN == 0 ? 0: nm_filter << 8;
v_srsrc[3] += 0x20000;
}
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
// 等待 MLS 数据回来
if constexpr (N_LOOP_STEP == 2) {
buffer_load_lds_dwordx1_wait_nosync<3 * V_LOAD_REQUESTS>();
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// 切换到 load 轮次
if constexpr (STAGES == 2) {
stage_id ^= 1;
}
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false); // hint: multiple prefetching can be applied here
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = (STAGES == 2) ? n_loop + 2: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
// ============================================================================================================
// 处理预取的第二段数据
if constexpr (N_LOOP_STEP == 2) {
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
__builtin_amdgcn_sched_barrier(0);
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = (STAGES == 2) ? n_loop + 3: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
}
}
if constexpr (STAGES == 2) {
// 等待 MLS 数据回来
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<1 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int n_loop = N_LOOP_END - N_LOOP_STEP;
// 切换
stage_id ^= 1;
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = n_loop + N_LOOP_STEP;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
// ============================================================================================================
// 处理预取的第二段数据
if constexpr (N_LOOP_STEP == 2) {
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
__builtin_amdgcn_sched_barrier(0);
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = n_loop + N_LOOP_STEP + 1;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void f16_mla_tp8_prefetch_v_to_lds_gfx938(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=0) {
constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
constexpr int N_LOOP_STEP = 2;
// 准备 MLS 的 resource 寄存器
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = kvcache_seqlen_stride; // stride
// 从倒数第 2 个 block 开始读取
int n_loop = kHeadDim / kBlockN - N_LOOP_STEP;
#pragma unroll
for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id) {
// 计算当前 wave 当前加载的 32x32 block 的偏移字节数
int v_mls_warp_global_offset = (n_loop + prefetch_id) * kBlockN * sizeof(Element);
// 计算当前 wave 写入 lds 的偏移地址(注意 v_lds 相较于 smem 的偏移量)
int v_mls_lds_warp_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32) * sizeof(Element);
// 计算当前 wave 读取数据的起始偏移字节数
int v_mls_loop_global_offset;// = warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
// 计算 MLS 读取数据的 global 地址, 判断边界
if constexpr (true) {
int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
v_mls_loop_global_offset = real_mls_warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
v_srsrc[3] = max_seq_kv_offset % kBlockN == 0 ? 0: nm_filter << 8;
v_srsrc[3] += 0x20000;
}
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
}
\ No newline at end of file
#pragma once
#include "f16_mla_tp8_pv_gemm_utils_gfx938.h"
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void f16_mla_tp8_qk_gemm_gfx938(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset=0) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
constexpr int K_LOAD_REQUESTS = (WARP_N / 32) * (kBlockK / 32);
// 分配 k 计算 mmac 需要的寄存器资源
// 一次加载 32x32 个 half, 每个线程持有 16 个 half
union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];
// 初始化 s
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
// 准备 MLS resource 寄存器
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = kcache_seqlen_stride;
int stage_id = 0;
constexpr int K_LOOP_START = (STAGES == 2) ? 2: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop += 2) {
#pragma unroll
for (int prefetch_id = 0; prefetch_id < 2; ++prefetch_id) {
// 计算当前 wave 写到 lds 的起始地址
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * K_LOAD_REQUESTS * (32 * 32);
// 计算当前 wave 沿着 kHeadDim 方向循环读取的起始地址, 读到第几个 32x32 块了
int k_mls_loop_global_offset = (k_loop + prefetch_id) * kBlockK * sizeof(Element);
// 计算当前 wave 从 global 读取数据的起始地址
int k_mls_warp_global_offset; // = warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
if constexpr (true) {
int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
k_mls_warp_global_offset = real_mls_warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
k_srsrc[3] = nm_filter << 8;
}
// 根据偏移计算 global load 的字节偏移数
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
}
// 等待 MLS 数据回来
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<3 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) stage_id ^= 1;
// 加载上一次 MLS 写到 lds 的数据到寄存器
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 2: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
// =============================================================================================================
{
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// 加载上一次 MLS 写到 lds 的数据到寄存器
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
if constexpr (STAGES == 2) {
// 等待 MLS 数据回来
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<1>();
__builtin_amdgcn_sched_barrier(0);
// 切换到上一次 lds 被写入的轮次
stage_id ^= 1;
// 从 lds 加载最后一部分数据
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 2;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
// ==========================================================================
{
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0>();
__builtin_amdgcn_sched_barrier(0);
// 从 lds 加载最后一部分数据
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
f16_mla_tp8_prefetch_v_to_lds_gfx938<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(v_addr, v_lds, warp_id, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
template<int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void f16_mla_tp8_prefetch_k_to_lds_gfx938(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=0) {
// 准备 MLS 寄存器
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = kvcache_seqlen_stride;
// pingpong buffer 的第一阶段
int stage_id = 0;
// kHeadDim 方向上的第几个 32x32 块
int k_loop = 0;
#pragma unroll
for (int prefetch_id = 0; prefetch_id < 2; ++prefetch_id) {
// 计算当前 wave 写到 lds 的起始地址
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
// 计算当前 wave 沿着 kHeadDim 方向循环读取的起始地址, 读到第几个 32x32 块了
int k_mls_loop_global_offset = (k_loop + prefetch_id) * kBlockK * sizeof(Element);
// 计算当前 wave 从 global 读取数据的起始地址
int k_mls_warp_global_offset; // = warp_id * WARP_N * kvcache_seqlen_stride;
if constexpr (true) {
int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
k_mls_warp_global_offset = real_mls_warp_id * WARP_N * kvcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
k_srsrc[3] = nm_filter << 8;
}
// 根据偏移计算 global load 的字节偏移数
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
}
\ No newline at end of file
......@@ -104,8 +104,8 @@ __forceinline__ __device__ void fp8_mla_acco_reduce_tile16x32(
data.f32[1] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 1 * 64];
data.f32[2] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 2 * 64];
data.f32[3] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 3 * 64];
acc_o[k_loop + 0][min_tile_n * 2].u64[0] = hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], data.u64[0]);
acc_o[k_loop + 0][min_tile_n * 2].u64[1] = hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], data.u64[1]);
acc_o[k_loop + 0][min_tile_n * 2].u64[0] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], data.u64[0]);
acc_o[k_loop + 0][min_tile_n * 2].u64[1] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], data.u64[1]);
}
}
__syncthreads();
......
......@@ -22,9 +22,9 @@ __forceinline__ __device__ void fp8_mla_epilugue_rescale_acco_gfx938(
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
......
......@@ -75,8 +75,8 @@ inline __device__ void fp8_mla_apply_descale_gfx938(DataType tensor[M_WARP_COUNT
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] = hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] = hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
}
}
}
......
......@@ -88,16 +88,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[1] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[2] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[3] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[0][0], v_f32x2[0][1]);
v_f16x8.f16x2[1] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[1][0], v_f32x2[1][1]);
v_f16x8.f16x2[2] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[2][0], v_f32x2[2][1]);
v_f16x8.f16x2[3] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[3][0], v_f32x2[3][1]);
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
......@@ -151,16 +151,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[1] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[2] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[3] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[0][0], v_f32x2[0][1]);
v_f16x8.f16x2[1] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[1][0], v_f32x2[1][1]);
v_f16x8.f16x2[2] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[2][0], v_f32x2[2][1]);
v_f16x8.f16x2[3] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[3][0], v_f32x2[3][1]);
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
......
......@@ -34,9 +34,9 @@ __forceinline__ __device__ void prefill_mla_epilugue_rescale_acco(
#pragma unroll
for(int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop) {
const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair
);
......@@ -115,8 +115,35 @@ __forceinline__ __device__ void prefill_mla_epilogue_store_output(
int pv_lane_head_dim_idx = lane_id >> 4;
if constexpr (Is_Interleaved) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
union_vec2_f16x2<Element> acc_o_fp16[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT];
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id;
if constexpr (M_MMAC_COUNT == 2)
mmac_id = min_tile_m + min_tile_n * 2;
else
mmac_id = min_tile_n;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
acc_o_fp16[k_loop][mmac_id].f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[k_loop][mmac_id].f32x2[vec_index]);
}
ds_mpermute_kdim_for_mmac(acc_o_fp16[k_loop][mmac_id].f32);
}
}
}
#endif
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
flash::wait_lds_data_arrived<false>((kHeadDimV / kBlockK - k_loop - 1) * 2 * 2);
#endif
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++warp_m_idx) {
#pragma unroll
......@@ -137,19 +164,29 @@ __forceinline__ __device__ void prefill_mla_epilogue_store_output(
// prepare for store
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = acc_o_fp16[k_loop][mmac_id];
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = acc_o_fp16[k_loop][mmac_id];
}
#else
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
// convert float -> bf16/fp16
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
#endif
}
}
}
......
......@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
// DS
......@@ -136,10 +133,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
}
stage_id ^= 1;
......@@ -200,10 +194,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
}
lds_stage_id ^= 1;
......@@ -309,4 +300,4 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
}
}
}
\ No newline at end of file
......@@ -29,9 +29,6 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512(
int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
......@@ -106,10 +106,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
if (k_even) {
k_stage_id ^= 1;
......@@ -122,10 +119,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
}
......@@ -317,11 +311,11 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
}
if constexpr (STAGES == 2) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset);
#else
#endif
}
} // qk_gemm
} // qk_gemm
\ No newline at end of file
......@@ -36,10 +36,7 @@ __forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512(
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // pvgemm 完成后会发射q,k的预取,避免有的warp还没完成,即规避读V写Q/K,造成数据覆盖
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
}
}
......@@ -71,9 +68,6 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512(
}
int lds_offset = (stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
}
\ No newline at end of file
......@@ -13,13 +13,13 @@ struct PrefillMlaAllreduce {
DataType res;
if constexpr (std::is_same<DataType, union_vec2_fp32>::value) {
if constexpr (std::is_same<Operator, SumOp<float> >::value) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
x.u64 = hcu_pk_add_f32(x.u64, res.u64);
x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
res.f32[0] = __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 16);
res.u64 = hcu_pk_add_f32(res.u64, x.u64);
res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
......@@ -100,7 +100,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
summary[m_idx * 2].u64 = 0x0;
} else {
......@@ -113,7 +113,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
if constexpr (M_MMAC_COUNT == 2){
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -146,7 +146,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
} else {
......@@ -159,7 +159,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -273,15 +273,14 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / (16 * M_MMAC_
mmac_id = min_tile_n;
}
int qk_tile_id = mi + ni * (WARP_M / (16 * M_MMAC_COUNT));
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[qk_tile_id][mmac_id].u64[vec_idx] = hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
}
......@@ -340,10 +339,10 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
} else {
mmac_id = min_tile_n;
}
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][mmac_id].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair
);
......@@ -372,9 +371,9 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum_cur);
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
scores_sum[mi].u64 = hcu_pk_add_f32(
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......@@ -390,7 +389,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
}
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__))
if constexpr (M_MMAC_COUNT == 2) {
inlineasm_fa_v_mov_b64(
scores_max[mi].u64,
......@@ -423,7 +422,7 @@ inline __device__ void prefill_mla_convert_pk_type(union_vec2_f16x2<Element> p_r
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (M_MMAC_COUNT == 2) {
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
......
......@@ -33,10 +33,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
// matrix load
__builtin_amdgcn_sched_barrier(0);
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
......@@ -63,10 +60,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
int q_warp_offset = (LOAD * WARP_NUM + real_warp_id) * 32;
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
__builtin_amdgcn_sched_barrier(0);
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
// continue from MID
......@@ -89,4 +83,4 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
// wait all data written to registers
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
}
\ No newline at end of file
......@@ -20,9 +20,9 @@ __forceinline__ __device__ void mla_epilugue_rescale_acco(
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
......
......@@ -15,12 +15,7 @@ __forceinline__ __device__ void mla_epilogue_store_max_sum_tile16x32(
int headdim_split_id,
int seqlen_q_limit
) {
#ifdef FA_DEBUG_SUM_MAX
constexpr bool ALLOW_WRITE_SUM_MAX = true;
#else
constexpr bool ALLOW_WRITE_SUM_MAX = false;
#endif
if constexpr (Split or ALLOW_WRITE_SUM_MAX) {
if constexpr (Split) {
if (headdim_split_id == 0) { // 因为 split-D 使用同样的 QK, 计算得到同样的 scores_sum/scores_max 会写多遍, 可能会有数据冲突, 所以强制只写一遍
if (thread_id < 16) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment