Commit a1eef562, authored by shenzhe, committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "utils.h"
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int stage_id, typename Element, bool Is_even_MN, int STAGES=2>
__forceinline__ __device__ void prefetch_v_to_lds(
vec4_uint gV,
Element* v_lds,
int WARP_ID,
int seqlen_v_stride,
int max_seq_kv_offset=-1) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr bool IS_HEADDIM_128 = (kHeadDim == 128 and kHeadDimV == 128) or (kHeadDim == 64 and kHeadDimV == 64);
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
#if defined(USE_BUFFER_LOAD_DWORDX4)
// 对于 headdim 128 而言, 2 组 32x32 可以写成 4 个 warp 分别读取 8 个 half, 即 4x64x8, 可以使用 buffer_load_dwordx4
// 对于其他 headdim 而言, 暂不做那么激进的优化, 只有 1 组 32x32, 最多用 buffer_load_dwordx2
constexpr int WARP_K = 32;
constexpr int READ_ONCE_LINES = IS_HEADDIM_128 ? 16: 8; // 一个 warp 一次读几行数据, loadx2, 每行 32 个元素需要 32 / (4) = 8 个线程
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 一次 load 多少数据
constexpr int V_LDS_LOAD_NUM = IS_HEADDIM_128 ? (2 * kBlockN * WARP_K/*对于非 headdim128 结尾的选项这里需要填 1, 一次只取 1 个 32x32 块*/) / READ_ONCE_COUNT: (kBlockN * WARP_K) / READ_ONCE_COUNT; // 整个 workgroup 要发多少 load 指令
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM; // 每个 warp 要发多少 load 指令
constexpr int READ_ELEMENT_COUNT = IS_HEADDIM_128 ? 8: 4; // 每个线程一次读取几个 half
int v_lane_headdim_n_idx = IS_HEADDIM_128 ? lane_id & 3: lane_id & 7; // 当前 lane 负责这个 warp 的第几个 dwordx4 或者 dwordx2
// 为了解决 ds_read2_b32 的 bank 冲突, 需要交换 0-15 线程的一些读取地址, 4 个线程为一组 (0, 1, 2, 3 ---> 0, 2, 1, 3 的写入位置, 从而满足 ds_read2_b32 offset32 的要求)
// 非 headdim 的话, 则交换 0-7 线程的一些读取地址, 4 个线程为一组
int base = IS_HEADDIM_128 ? (laneid_shfl_2 & 0xc): (laneid_shfl_3 & 0x4); // 第几个4线程组的最小id
int tail = IS_HEADDIM_128 ? (laneid_shfl_2 & 0x3): (laneid_shfl_3 & 0x3); // 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = IS_HEADDIM_128 ? &inline_buffer_load_dwordx4_lds<Element, 2>: &inline_buffer_load_dwordx2_lds<Element, 2>;
#else
constexpr int WARP_K = 32;
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + v_lane_headdim_n_idx * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
int n_loop = 0;
int v_block_buffer_load_global_offset = n_loop * kBlockN;
for(int load = 0, warp_loop = WARP_ID; load < V_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int v_warp_buffer_load_k_id = warp_loop / (kBlockN / 32);
int s_offset = v_block_buffer_load_global_offset / 2;
int seqlen_pos = v_lane_seq_k_idx + v_warp_buffer_load_k_id * READ_ONCE_LINES;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_kv_offset - 1);
}
int v_offset = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + seqlen_pos * seqlen_v_stride) / 2;
int v_lds_offset = v_warp_buffer_load_k_id * READ_ONCE_COUNT / 2;
BUFFER_LOAD_FUNC(v_lds + (stage_id * STAGES) * WARP_K * kBlockN, gV, v_lds_offset, s_offset, v_offset);
}
__builtin_amdgcn_sched_barrier(0);
if (IS_HEADDIM_128) {
#if !defined(USE_BUFFER_LOAD_DWORDX4)
// for ZD, double prefetch bring degradation; for BMZ, may bring improvement
int n_loop = 0;
for(int load = 0, warp_loop = WARP_ID; load < V_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int v_warp_buffer_load_k_id = warp_loop / (kBlockN / 32);
int s_offset = v_block_buffer_load_global_offset / 2;
int seqlen_pos = v_lane_seq_k_idx + v_warp_buffer_load_k_id * READ_ONCE_LINES + WARP_K;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_kv_offset - 1);
}
int v_offset = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + seqlen_pos * seqlen_v_stride) / 2;
int v_lds_offset = v_warp_buffer_load_k_id * READ_ONCE_COUNT / 2;
BUFFER_LOAD_FUNC(v_lds + (stage_id * STAGES + 1) * WARP_K * kBlockN, gV, v_lds_offset, s_offset, v_offset);
}
__builtin_amdgcn_sched_barrier(0);
#endif
}
}
\ No newline at end of file
#pragma once
#include "qk_gemm_prefetch_v_headdim128.h"
#define USE_DS_OVERLAP_MMAC
namespace flash {
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void qk_gemm_prefetch_v(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / 32)][4],
int WARP_ID,
int seqlen_k_stride,
int seqlen_v_stride,
int max_seq_k_offset= - 1) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
union_vec4_f16x2<Element> k_reg[STAGES * (WARP_N * kBlockK) / (32 * 32) * 2];
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int q_lds_load_num = kBlockM * kBlockK / (4 * 32);
constexpr int Q_LOAD_REQUESTS = q_lds_load_num / WARP_NUM;
constexpr int k_lds_load_num = WARP_N * kBlockK / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
constexpr int QK_LOOP_COUNT = kHeadDim / kBlockK;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int q_warp_m_id = WARP_ID & ((kBlockM / WARP_M) - 1);
int k_warp_n_id = WARP_ID & (WARP_N / WARP_N - 1);
int q_ds_read_offset = WARP_ID * (WARP_M / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
int stage_id = 0;
#pragma unroll
for (int i = 0; i < (kBlockN / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
inline_vgpr4_init_zero(s_reg[i][j]);
}
}
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for(int k_loop = K_LOOP_START; k_loop < QK_LOOP_COUNT; ++k_loop) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 在 wait 之前提前计算这部分偏移量
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait<0>();
}
#else
buffer_load_lds_dwordx1_wait<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
// 保留第 1 阶段最后一波数据实际的 stage_id
int last_stage_id = stage_id ^ 1;
// 先把第 2 阶段的 load 指令先发出去
if constexpr (kBlockN >= (WARP_N * 2)) {
// stage_id = 0;
if constexpr (STAGES == 2) {
int k_loop = 0;
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
}
}
// 等待第 1 阶段最后一波数据返回做计算
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = last_stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 2)) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<0>();
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
// 第 2 阶段主循环
if constexpr (kBlockN >= (WARP_N * 2)) {
if constexpr (STAGES == 2) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
for(int k_loop = K_LOOP_START; k_loop<QK_LOOP_COUNT; ++k_loop) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait<0>();
}
#else
buffer_load_lds_dwordx1_wait<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 保留第 2 阶段最后一波数据实际的 stage_id
last_stage_id = stage_id ^ 1;
// 先把第 3 阶段的 load 指令先发出去
if constexpr (kBlockN >= (WARP_N * 3)) {
// stage_id = 0;
if constexpr (STAGES == 2) {
int k_loop = 0;
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 2 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
}
}
// 等待第 2 阶段最后一波数据返回做计算
if constexpr (kBlockN >= (WARP_N * 2)) {
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = last_stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_dwordx1_wait<V_LOAD_REQUESTS>(); // when use prefetch V
if constexpr (kBlockN >= (WARP_N * 3)) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<0>();
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 第 3 阶段主循环
if constexpr (kBlockN >= (WARP_N * 3)) {
if constexpr (STAGES == 2) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
for(int k_loop = K_LOOP_START; k_loop<QK_LOOP_COUNT; ++k_loop) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 2 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait<0>();
}
#else
buffer_load_lds_dwordx1_wait<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 保留第 3 阶段最后一波数据实际的 stage_id
last_stage_id = stage_id ^ 1;
// 先把第 4 阶段的 load 指令先发出去
if constexpr (kBlockN >= (WARP_N * 4)) {
// stage_id = 0;
if constexpr (STAGES == 2) {
int k_loop = 0;
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 3 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
}
}
// 等待第 3 阶段最后一波数据返回做计算
if constexpr (kBlockN >= (WARP_N * 3)) {
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = last_stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_dwordx1_wait<V_LOAD_REQUESTS>(); // when use prefetch V
if constexpr (kBlockN >= (WARP_N * 4)) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<0>();
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 第 4 阶段主循环
if constexpr (kBlockN >= (WARP_N * 4)) {
if constexpr (STAGES == 2) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
for(int k_loop = K_LOOP_START; k_loop<QK_LOOP_COUNT; ++k_loop) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 3 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait<0>();
}
#else
buffer_load_lds_dwordx1_wait<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
// a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 等待第 4 阶段最后一波数据返回做计算
if constexpr (kBlockN >= (WARP_N * 4)) {
if constexpr (STAGES == 2) {
stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_dwordx1_wait<V_LOAD_REQUESTS>(); // when use prefetch V
buffer_load_lds_dwordx1_wait<0>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int q_tile_id = (QK_LOOP_COUNT - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
constexpr int V_LOAD_REQUESTS = (WARP_M * kBlockK) / (4 * 32) / WARP_NUM;
if constexpr (STAGES == 2) {
if constexpr (Is_even_MN)
prefetch_v_to_lds<kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 0, Element, Is_even_MN>(gV, v_lds, WARP_ID, seqlen_v_stride);
else
prefetch_v_to_lds<kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 0, Element, Is_even_MN>(gV, v_lds, WARP_ID, seqlen_v_stride, max_seq_k_offset);
}
} // qk_gemm
} // namespace flash
#pragma once
#include "qk_gemm_utils.h"
namespace flash {
// #define USE_SCHEDULE_0_INIT
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void qk_gemm_prefetch_v_headdim128(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (kBlockN / 32)][4],
int WARP_ID,
int seqlen_k_stride,
int seqlen_v_stride,
int max_seq_k_offset=-1) {
if constexpr (kHeadDim == 128) {
static_assert(STAGES == 2 and "For double prefetch in headdim 128/64, only STAGES=2 is supported!\n");
static_assert(kBlockN >= 64 and "For double prefetch in headdim 128/64, only BLOCK_N >= 64 is supported!\n");
}
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
constexpr int QK_LOOP_COUNT = kHeadDim / kBlockK;
union_vec4_f16x2<Element> k_reg[(WARP_N * kBlockK) / (32 * 32) * 2];
union_vec4_f16x2<Element> k_reg_tmp[(WARP_N * kBlockK) / (32 * 32) * 2];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int k_warp_n_id = WARP_ID & (WARP_N / WARP_N - 1);
int q_ds_read_offset = WARP_ID * (WARP_M / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15>>1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15>>1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
constexpr int q_lds_load_num = kBlockM * kBlockK / (4 * 32);
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int Q_LOAD_REQUESTS = q_lds_load_num / WARP_NUM;
constexpr int k_lds_load_num = WARP_N * kBlockK / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
#pragma unroll
for (int i = 0; i < (kBlockN / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
inline_vgpr4_init_zero(s_reg[i][j]);
}
}
// 准备 load 下一轮数据
int stage_id = 0;
stage_id ^= 1;
// 第 1, 2 阶段主循环
constexpr int K_LOOP_START = 1;
for(int k_loop = K_LOOP_START; k_loop < QK_LOOP_COUNT; ++k_loop) {
// load 下一轮的第 1 阶段数据
{
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32)) ; ;
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 提前 load 下一轮的第 2 阶段数据, 让 load 指令飞得更久一点
{
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32)) ; ;
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 切到计算的轮次
stage_id ^= 1;
// 在 wait 之前提前计算这部分 lds load 的偏移量
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 看 threadtrace 的话, 最前面发出去 2 + 2 笔请求, 刚才又发出去 2 + 2 笔请求, 现在等第一个 2 笔请求, 需要 wait vmcnt(6)
buffer_load_lds_dwordx1_wait<3 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// 第一笔的数据 lds -> vgpr
{
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
// 预取第二阶段的第一笔数据, 先计算需要的 lds offset
if constexpr (kBlockN >= (WARP_N * 2)) {
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
// 预取第二阶段的第一笔数据, overlap 上面的 ds 读取的时延, 把 ds 指令提前发出去
__builtin_amdgcn_sched_barrier(0);
// 发出去 4 笔 K_LOAD_REQUESTS 的请求, 这里等待第 2 笔数据
buffer_load_lds_dwordx1_wait < 2 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
}
// 发下去 4 笔 ds_read2_b32 请求, 先等前两笔结果的返回, 但由于预发了 4 笔 ds_read2_b32 请求, 所以 2 + 4 = 6
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
// 等后两笔 ds_read2_b32 的结果
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 第 2 笔数据 lds -> vgpr
if constexpr (kBlockN >= (WARP_N * 2)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 保留第 1,2 阶段最后一波数据需要计算的 id
int last_stage_id = stage_id ^ 1;
// Prefetch: issue the stage-3 load instructions now, before the remaining
// stage-1/2 data is waited on and consumed further below.
// Only compiled in when the K block holds at least three WARP_N-row slices.
if constexpr (kBlockN >= (WARP_N * 3)) {
// k_loop is fixed at 0 here, so the global offset along head-dim is 0 (first kBlockK chunk).
int k_loop = 0;
int k_block_buffer_load_global_offset = k_loop * kBlockK;
// Base offset of LDS slot (stage_id * STAGES); each 32-row tile takes 32 * 34 units
// (presumably 16-bit elements incl. padding, since offsets are halved before the load — TODO confirm).
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
// Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
// Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
// Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int s_offset = k_block_buffer_load_global_offset / 2;
// Row along the sequence dimension; + 2 * WARP_N selects the third WARP_N-row slice (stage 3).
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 2 * WARP_N;
if constexpr (not Is_even_MN) {
// Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
// Presumably an asynchronous global -> LDS dword load (completion is awaited later via
// buffer_load_lds_dwordx1_wait) — verify against the helper's definition.
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// Prefetch: issue the stage-4 load instructions now as well.
// Only compiled in when the K block holds at least four WARP_N-row slices.
if constexpr (kBlockN >= (WARP_N * 4)) {
// k_loop is fixed at 0 here, so the global offset along head-dim is 0 (first kBlockK chunk).
int k_loop = 0;
int k_block_buffer_load_global_offset = k_loop * kBlockK;
// Same layout as the stage-3 prefetch, but targets the second LDS slot of the pair (stage_id * STAGES + 1).
int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
// Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
// Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
// Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int s_offset = k_block_buffer_load_global_offset / 2;
// Row along the sequence dimension; + 3 * WARP_N selects the fourth WARP_N-row slice (stage 4).
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 3 * WARP_N;
if constexpr (not Is_even_MN) {
// Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
// Presumably an asynchronous global -> LDS dword load (completion is awaited later via
// buffer_load_lds_dwordx1_wait) — verify against the helper's definition.
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 等待第 1,2 阶段最后一波数据返回做计算
{
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (last_stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
// 原来发了 4 笔 K_LOAD_REQUESTS, 已经算了 2 笔, 这里在等第 3 笔, 但是预取了 3,4 阶段的数据, 所以 1 + 2 = 3 个 K_LOAD_REQUESTS
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 3)) {
buffer_load_lds_dwordx1_wait<3 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<1 * K_LOAD_REQUESTS>();
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
if constexpr (kBlockN >= (WARP_N * 2)) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// 等待最后 1 笔数据的返回, 但是预取了 3, 4 阶段的数据
if constexpr (kBlockN >= (WARP_N * 3)) {
buffer_load_lds_dwordx1_wait < 2 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<0>();
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 等第 2 阶段的最后一波数据回来计算
if constexpr (kBlockN >= (WARP_N * 2)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/1 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 第 3, 4 阶段主循环
if constexpr (kBlockN >= (WARP_N * 3)) {
// 切到 load 数据的轮次
stage_id ^= 1;
for(int k_loop = K_LOOP_START; k_loop < QK_LOOP_COUNT; ++k_loop) {
// 发第 3 阶段的下一个轮次的 load 请求
{
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32)) ; ;
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 2 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 提前把第 4 阶段的 load 发下去, 放 load 指令飞的更久一点
if constexpr (kBlockN >= (WARP_N * 4)) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32)) ; ;
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 3 * WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 切到计算的轮次
stage_id ^= 1;
// 在 wait 前先计算好 lds offset
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait<3 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
if constexpr (kBlockN >= (WARP_N * 4)) {
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait < 2 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 等待第 4 阶段的数据
if constexpr (kBlockN >= (WARP_N * 4)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
}
// 保留第 3, 4 阶段最后一波数据实际的 stage_id
last_stage_id = stage_id ^ 1;
// Issue the stage-5 load instructions ahead of time (prefetch), before the last
// stage-3/4 data is waited on and consumed below.
// Only compiled in when the K block holds at least five WARP_N-row slices.
if constexpr (kBlockN >= (WARP_N * 5)) {
// k_loop is fixed at 0 here, so the global offset along head-dim is 0 (first kBlockK chunk).
int k_loop = 0;
int k_block_buffer_load_global_offset = k_loop * kBlockK;
// Base offset of LDS slot (stage_id * STAGES); each 32-row tile takes 32 * 34 units
// (presumably 16-bit elements incl. padding, since offsets are halved before the load — TODO confirm).
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
// Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
// Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
// Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int s_offset = k_block_buffer_load_global_offset / 2;
// Row along the sequence dimension; + 4 * WARP_N selects the fifth WARP_N-row slice (stage 5).
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 4 * WARP_N;
if constexpr (not Is_even_MN) {
// Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
// Presumably an asynchronous global -> LDS dword load (completion is awaited later via
// buffer_load_lds_dwordx1_wait) — verify against the helper's definition.
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// Issue the stage-6 load instructions ahead of time (prefetch).
// Only compiled in when the K block holds at least six WARP_N-row slices.
// Stage N loads the N-th WARP_N-row slice of the K block (stage 3 -> + 2 * WARP_N,
// stage 4 -> + 3 * WARP_N, stage 5 -> + 4 * WARP_N), so stage 6 must start at
// 5 * WARP_N. Using 4 * WARP_N here would re-load the stage-5 rows into the
// stage-6 LDS slot.
if constexpr (kBlockN >= (WARP_N * 6)) {
    // First kBlockK chunk along head-dim: the global offset is 0.
    int k_block_buffer_load_global_offset = 0 * kBlockK;
    // Second LDS slot of the current pair (stage_id * STAGES + 1); each 32-row tile
    // takes 32 * 34 units (presumably 16-bit elements incl. padding — TODO confirm).
    int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
    // Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
    for (int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
        int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
        // Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
        int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
        // Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
        int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
        int s_offset = k_block_buffer_load_global_offset / 2;
        // Row along the sequence dimension; + 5 * WARP_N selects the sixth WARP_N-row slice (stage 6).
        int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 5 * WARP_N;
        if constexpr (not Is_even_MN) {
            // Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
            seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
        }
        int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
        int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
        inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
    }
}
// 等待第 3, 4 阶段最后一波数据返回做计算
if constexpr (kBlockN >= (WARP_N * 3)) {
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (last_stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 5)) {
buffer_load_lds_dwordx1_wait<3 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<1 * K_LOAD_REQUESTS>();
}
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
if constexpr (kBlockN >= (WARP_N * 4)) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 5)) {
buffer_load_lds_dwordx1_wait < 2 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait<0 * K_LOAD_REQUESTS>();
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/2 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 等第 4 阶段 load 指令的最后一波数据回来
if constexpr (kBlockN >= (WARP_N * 4)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/3 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 第 5, 6 阶段主循环
if constexpr (kBlockN >= (WARP_N * 5)) {
// 切到 load 数据的轮次
stage_id ^= 1;
for(int k_loop = K_LOOP_START; k_loop < QK_LOOP_COUNT; ++k_loop) {
// Issue the stage-5 load requests for the next round (head-dim chunk k_loop).
{
// Global offset along head-dim for this round.
int k_block_buffer_load_global_offset = k_loop * kBlockK;
// Base offset of LDS slot (stage_id * STAGES); each 32-row tile takes 32 * 34 units
// (presumably 16-bit elements incl. padding, since offsets are halved before the load — TODO confirm).
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
// Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
// Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
// Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32)) ; ;
int s_offset = k_block_buffer_load_global_offset / 2;
// Row along the sequence dimension; + 4 * WARP_N selects the fifth WARP_N-row slice (stage 5).
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 4 * WARP_N;
if constexpr (not Is_even_MN) {
// Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
// Presumably an asynchronous global -> LDS dword load (completion is awaited later via
// buffer_load_lds_dwordx1_wait) — verify against the helper's definition.
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// Issue the stage-6 load early as well, so the load instructions stay in flight longer.
// Stage N loads the N-th WARP_N-row slice of the K block (stage 3 -> + 2 * WARP_N,
// stage 4 -> + 3 * WARP_N, stage 5 -> + 4 * WARP_N), so stage 6 must start at
// 5 * WARP_N. Using 4 * WARP_N here would re-load the stage-5 rows into the
// stage-6 LDS slot.
if constexpr (kBlockN >= (WARP_N * 6)) {
    // Global offset along head-dim for this round.
    int k_block_buffer_load_global_offset = k_loop * kBlockK;
    // Second LDS slot of the current pair (stage_id * STAGES + 1); each 32-row tile
    // takes 32 * 34 units (presumably 16-bit elements incl. padding — TODO confirm).
    int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
    // Each warp issues K_LOAD_REQUESTS loads; warp_loop starts at WARP_ID and strides by WARP_NUM.
    for (int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
        int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
        // Index of the 4-row group handled by this request, wrapped to WARP_N / 4 groups.
        int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
        // Groups are packed 8 per 32-row tile: (n_id >> 3) picks the tile, (n_id & 7) the group inside it.
        int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
        int s_offset = k_block_buffer_load_global_offset / 2;
        // Row along the sequence dimension; + 5 * WARP_N selects the sixth WARP_N-row slice (stage 6).
        int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + 5 * WARP_N;
        if constexpr (not Is_even_MN) {
            // Ragged tail: clamp to the last valid row so the load does not go past max_seq_k_offset.
            seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
        }
        int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
        int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
        inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
    }
}
// 切到计算的轮次
stage_id ^= 1;
// 在 wait 前先计算好 lds offset
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = (stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait<3 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
if constexpr (kBlockN >= (WARP_N * 4)) {
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait < 2 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 等待第 6 阶段的数据
if constexpr (kBlockN >= (WARP_N * 6)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = k_loop - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
}
// 保留第 5, 6 阶段最后一波数据实际的 stage_id
last_stage_id = stage_id ^ 1;
// 等待第 5, 6 阶段最后一波数据返回做计算
if constexpr (kBlockN >= (WARP_N * 5)) {
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (last_stage_id * STAGES) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait<1 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
if constexpr (kBlockN >= (WARP_N * 6)) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] += 1 * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait<0 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg_tmp[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(6)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(4)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/4 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// 等第 6 阶段 load 指令的最后一波数据回来
if constexpr (kBlockN >= (WARP_N * 6)) {
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
int k_loop_idx = QK_LOOP_COUNT - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg_tmp[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[(/*n_loop*/5 * (WARP_N / 32) + n_idx) * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
asm volatile("s_barrier ; sync before load in the coming round");
}
}
__builtin_amdgcn_sched_barrier(0);
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
constexpr int V_LOAD_REQUESTS = (WARP_M * kBlockK) / (4 * 32) / WARP_NUM;
if constexpr (STAGES == 2) {
if constexpr (Is_even_MN)
prefetch_v_to_lds<kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 0, Element, Is_even_MN>(gV, v_lds, WARP_ID, seqlen_v_stride);
else
prefetch_v_to_lds<kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 0, Element, Is_even_MN>(gV, v_lds, WARP_ID, seqlen_v_stride, max_seq_k_offset);
}
} // qk_gemm
} // namespace flash
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "pv_gemm_utils.h"
// Moves this block's Q tile into per-thread registers (q_reg), staging it
// through q_lds one kBlockK-wide slice of the head dimension at a time.
//
// Flow, as written below:
//   1. Issue the loads for slice 0 into staging slot 0.
//   2. For every further slice k_loop: issue its loads into the other slot,
//      wait for all outstanding memory operations plus a block barrier, then
//      read slice k_loop - 1 back from its slot into q_reg.
//   3. After the loop, wait once more and read the final slice into q_reg.
// stage_id is toggled with `^= 1`, so only two staging slots are ever used.
//
// Parameters:
//   gQ               - passed unchanged to builtin_buffer_load_dword_lds
//                      (presumably the buffer descriptor of Q -- TODO confirm).
//   q_lds            - staging buffer written by the buffer loads and read back
//                      through inline_ds_read_b32_wait.
//   q_reg            - destination registers; the first index encodes
//                      (slice, head_dim_idx, m_idx, i), the second index is j.
//   WARP_ID          - index of the calling warp; selects which load requests
//                      and which WARP_M-row band this warp handles.
//   seqlen_q_stride  - multiplied by the row position and halved to form the
//                      per-lane load offset (units depend on the load helper).
//   max_seq_q_offset - only read when Is_even_MN is false: row positions are
//                      clamped to max_seq_q_offset - 1. NOTE(review): the default
//                      of -1 would clamp every row to -2, so callers with
//                      Is_even_MN == false must presumably always pass it.
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_q_to_vgpr(
vec4_uint gQ,
Element* q_lds,
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4],
int WARP_ID,
int seqlen_q_stride,
int max_seq_q_offset=-1) {
// Number of warps that share the kBlockM rows, and how many load requests
// (each covering 4 rows x 32 elements) every warp has to issue per slice.
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int q_lds_load_num = kBlockM * kBlockK / (4 * 32);
constexpr int Q_LOAD_REQUESTS = q_lds_load_num / WARP_NUM;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = ((lane_id >> 4) & 1) * 2 + ((lane_id >> 4) >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
// Low 4 bits of the lane select the position along the head dimension.
int q_lane_head_dim_idx = lane_id & 15;
int stage_id = 0;
{
// Prologue: request slice 0 into staging slot 0.
int k_loop = 0;
// global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * kBlockK;
// NOTE: shadows the constexpr q_lds_load_num declared above and is not used
// inside this scope.
const int q_lds_load_num = kBlockM * kBlockK / (4 * 32);
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0,warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
// Each group of 8 requests fills one 32-row band of size 32 * 34.
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int s_offset = q_block_buffer_load_global_offset / 2;
int seqlen_pos = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if constexpr (not Is_even_MN) {
// Clamp rows past the end of the sequence onto the last valid row.
seqlen_pos = min(seqlen_pos, max_seq_q_offset - 1);
}
int v_offset = seqlen_pos * seqlen_q_stride / 2 + q_lane_head_dim_idx;
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, s_offset, v_offset);
}
}
stage_id ^= 1;
// Steady state: request slice k_loop, then consume slice k_loop - 1.
for(int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0,warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int s_offset = q_block_buffer_load_global_offset / 2;
int seqlen_pos = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_q_offset - 1);
}
int v_offset = seqlen_pos * seqlen_q_stride / 2 + q_lane_head_dim_idx;
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, s_offset, v_offset);
}
// Wait for every outstanding memory operation, then make the staged data
// visible to the whole block before anyone reads it.
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// Switch to the slot that holds slice k_loop - 1.
stage_id ^= 1;
// Read-side offsets index vec2_Element pairs, hence 32 * 17 here versus
// 32 * 34 on the write side.
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + (WARP_ID * (WARP_M / 32) + m_idx) * (32 * 17) + j * 2 + i * 32 + (lane_id & 1) * 16 + ((lane_id & 15) >> 1) * 64 + /*padding*/ ((lane_id & 15) >> 1) + ((lane_id / 16) & 1) * 8 + (lane_id / 32);
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[(k_loop - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i][j]);
}
}
}
}
// The slot just read is the one the next iteration's loads will overwrite,
// so all warps must be done reading before continuing.
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
// Epilogue: the last slice has been requested but not yet consumed.
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + (WARP_ID * (WARP_M / 32) + m_idx) * (32 * 17) + j * 2 + i * 32 + (lane_id & 1) * 16 + ((lane_id & 15) >> 1) * 64 + /*padding*/ ((lane_id & 15) >> 1) + ((lane_id / 16) & 1) * 8 + (lane_id / 32);
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i][j]);
}
}
}
}
// Final fence: finish all reads and synchronise the block; the scheduling
// barriers keep the compiler from moving instructions across this point.
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
// Issues the first K loads of the QK GEMM ahead of time.
//
// Only the k_loop == 0 slice of the head dimension is requested here:
//   * always: rows [0, WARP_N) of K, written to staging slot 0;
//   * when kHeadDim is 128 or 64: additionally rows [WARP_N, 2 * WARP_N),
//     written to the slot at index (stage_id * STAGES + 1).
// The function only issues the loads (inline_buffer_load_dword_lds); it does
// not wait for them, so the consumer is responsible for waiting.
//
// Parameters:
//   gK               - passed unchanged to inline_buffer_load_dword_lds
//                      (presumably the buffer descriptor of K -- TODO confirm).
//   k_lds            - staging buffer the loads write into.
//   WARP_ID          - index of the calling warp; selects its load requests.
//   seqlen_k_stride  - multiplied by the row position and halved to form the
//                      per-lane load offset (units depend on the load helper).
//   max_seq_k_offset - only read when Is_even_MN is false: row positions are
//                      clamped to max_seq_k_offset - 1. NOTE(review): the default
//                      of -1 is not usable in that case; callers presumably
//                      always pass a real value then.
// Note that WARP_NUM is a template parameter here rather than being derived
// from kBlockN / WARP_N (see the commented-out line below).
template<int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN, int STAGES=2>
__forceinline__ __device__ void prefetch_k_to_lds(
vec4_uint gK,
Element* k_lds,
int WARP_ID,
int seqlen_k_stride,
int max_seq_k_offset=-1) {
// constexpr int WARP_NUM = kBlockN / WARP_N;
// Each load request covers 4 rows x 32 elements; the requests are split
// evenly across the WARP_NUM warps.
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
// stage_id and k_loop stay 0 for the whole function: this is the prologue
// of the pipeline only.
int stage_id = 0;
int k_loop = 0;
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = stage_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
// First WARP_N rows of K.
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
// Each group of 8 requests fills one 32-row band of size 32 * 34.
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int s_offset = k_block_buffer_load_global_offset / 2;
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx;
if constexpr (not Is_even_MN) {
// Clamp rows past the end of the sequence onto the last valid row.
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
// Keep the compiler from reordering instructions across the two load groups.
__builtin_amdgcn_sched_barrier(0);
if constexpr (kHeadDim == 128 or kHeadDim == 64) {
// Second WARP_N rows of K. The two locals below intentionally shadow the
// outer ones; the slot index is shifted by one.
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (stage_id * STAGES + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = warp_loop & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int s_offset = k_block_buffer_load_global_offset / 2;
// Same row mapping as above, shifted by WARP_N rows.
int seqlen_pos = k_warp_buffer_load_n_id * 4 + qk_lane_m_idx + WARP_N;
if constexpr (not Is_even_MN) {
seqlen_pos = min(seqlen_pos, max_seq_k_offset - 1);
}
int v_offset = seqlen_pos * seqlen_k_stride / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "philox.cuh"
#include "utils.h"
using namespace flash;
// Masks out key columns that lie at or beyond max_seqlen_k.
//
// `tensor` holds one warp's score tile: entry [mi + ni * (WARP_M / 32)] is the
// 32x32 sub-tile at row-tile mi / column-tile ni, its second index is
// (min_tile_n * 2 + min_tile_m), and .f32[v] selects one of four values owned
// by this lane. The column of such a value is
//     col_idx_offset_ + (lane >> 4) * 2 + ni * 32 + min_tile_n + v * 8.
// Every value whose column is >= max_seqlen_k is overwritten with -INFINITY.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k,
                                  const int col_idx_offset_ = 0) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    // First column handled by this lane inside a 32-wide column tile.
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
        #pragma unroll
        for (int sub_n = 0; sub_n < 2; ++sub_n) {
            #pragma unroll
            for (int lane_vec = 0; lane_vec < 4; ++lane_vec) {
                const int col = first_col + tile_n * 32 + sub_n + lane_vec * 8;
                if (col < max_seqlen_k) {
                    continue;  // column is a real key, keep the scores
                }
                // The column is out of range for every row this lane owns.
                #pragma unroll
                for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
                    #pragma unroll
                    for (int sub_m = 0; sub_m < 2; ++sub_m) {
                        tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[lane_vec] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Zeroes elements of one warp's tile according to Philox-generated random bytes.
//
// For every (row-tile, column-tile) pair one call to flash::philox yields four
// 32-bit words, read back as 16 bytes: one byte per value this lane owns in
// that 32x32 sub-tile, indexed by (min_tile_n * 2 + min_tile_m) * 4 + vec_idx.
// A value is set to zero when its byte is greater than p_dropout_in_8bits_value.
// When Is_even_MN is false, values whose column is >= max_seqlen_k are left
// untouched.
//
// rowcol is taken by value and advanced locally: .u32.x grows by
// BLOCK_ROW_STRIDE per row tile and .u32.y by one per column tile.
// NOTE(review): .u32.y is never reset when moving to the next row tile, so for
// WARP_M > 32 later row tiles continue from the previous column counter.
// Confirm this matches whatever regenerates the mask elsewhere.
//
// dropout_debug_count is only touched when FA_DEBUG is defined.
template <typename DataType, int WARP_M, int WARP_N, int BLOCK_ROW_STRIDE, bool Is_even_MN>
inline __device__ void apply_dropout(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k, const int col_idx_offset_,
                                     unsigned long long seed, unsigned long long offset, uint32_t p_dropout_in_8bits_value,
                                     union_vec2_uint rowcol, uint32_t* dropout_debug_count) {
    const int lane = threadIdx.x & 63; // lane id, 0-63
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    // Four 32-bit words viewed as sixteen 8-bit random numbers.
    union_vec4_uint rand_bytes;
    for (int tile_m = 0; tile_m < (WARP_M / 32); ++tile_m, rowcol.u32.x += BLOCK_ROW_STRIDE) {
        #pragma unroll
        for (uint32_t tile_n = 0; tile_n < (WARP_N / 32); ++tile_n, ++rowcol.u32.y) {
            // One Philox draw covers the 16 values of this 32x32 sub-tile.
            rand_bytes.u32 = flash::philox(seed, rowcol.u64, offset);
            #pragma unroll
            for (uint32_t sub_n = 0; sub_n < 2; ++sub_n) {
                const int tile_col = first_col + tile_n * 32 + sub_n;
                #pragma unroll
                for (uint32_t lane_vec = 0; lane_vec < 4; ++lane_vec) {
                    // With an even tiling every column is valid; otherwise
                    // columns past the end of the key sequence are skipped.
                    bool col_valid = true;
                    if constexpr (not Is_even_MN) {
                        const int col = tile_col + lane_vec * 8;
                        col_valid = col < max_seqlen_k;
                    }
                    if (col_valid) {
                        #pragma unroll
                        for (uint32_t sub_m = 0; sub_m < 2; ++sub_m) {
                            const uint32_t byte_pos = (sub_n * 2 + sub_m) * 4 + lane_vec;
                            // Widen the byte to 32 bits before comparing.
                            const uint32_t rnd = rand_bytes.u8[byte_pos] & 0xffffffff;
                            if (rnd > p_dropout_in_8bits_value) {
                                tensor[tile_m + tile_n * (WARP_M / 32)][sub_n * 2 + sub_m].f32[lane_vec] = 0x0;
#ifdef FA_DEBUG
                                atomicAdd(dropout_debug_count, 1);
#endif
                            }
                        }
                    }
                }
            }
        }
    }
}
// Applies a causal mask to one warp's score tile.
//
// A value at (row, col) is overwritten with -INFINITY when
//     col > min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q).
// Rows and columns are reconstructed from the lane id:
//     row = row_idx_offset_ + (lane & 15) * 2 + mi * 32 + min_tile_m
//     col = col_idx_offset_ + (lane >> 4) * 2 + ni * 32 + min_tile_n + v * 8
// Note: for rows with row >= max_seqlen_q the bound clamps to max_seqlen_k, so
// column max_seqlen_k itself is not masked for those rows.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask_causal(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset_,
                                         const int max_seqlen_q) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15) * 2;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < 2; ++sub_m) {
            const int row = first_row + tile_m * 32 + sub_m;
            // Last column this row may attend to (inclusive).
            // attention, when max_seqlen_k == max_seqlen_q, vgpr can be reduced again
            const int last_visible_col = std::min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int lane_vec = 0; lane_vec < 4; ++lane_vec) {
                        const int col = first_col + tile_n * 32 + sub_n + lane_vec * 8;
                        if (col > last_visible_col) {
                            tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[lane_vec] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Applies a sliding-window (local attention) mask to one warp's score tile.
//
// For the value at (row, col), with
//     row = row_idx_offset_ + (lane & 15) * 2 + mi * 32 + min_tile_m
//     col = col_idx_offset_ + (lane >> 4) * 2 + ni * 32 + min_tile_n + v * 8
// the value is overwritten with -INFINITY when the column is
//   * to the right of the window:
//         col > min(max_seqlen_k - 1, row + max_seqlen_k - max_seqlen_q + window_size_right)
//   * or, only if HasWSLeft, to the left of the window:
//         col < max(0, row + 1 + max_seqlen_k - max_seqlen_q - window_size_left) - 1
//
// The right bound is inclusive, so it is clamped to the last valid key index
// max_seqlen_k - 1. Clamping to max_seqlen_k (as this function did before) let
// column max_seqlen_k, which is already past the end of the keys, through
// whenever the window reached the end of the sequence; apply_mask treats every
// col >= max_seqlen_k as invalid and this function now agrees with it.
template <bool HasWSLeft=true, typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_mask_local(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset_,
                                        const int max_seqlen_q,
                                        const int window_size_left, const int window_size_right) {
    const int lane_id = threadIdx.x & 63;
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15) * 2;
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m;
            // Left bound is kept in its original "+1 then compare against limit - 1" form.
            const int col_idx_limit_left = std::max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
            // Inclusive right bound, never beyond the last real key column.
            const int col_idx_limit_right = std::min(max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx * 8;
                        const bool outside_window = col_idx > col_idx_limit_right || (HasWSLeft && col_idx < (col_idx_limit_left - 1));
                        tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = outside_window ?
                            -INFINITY : tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx];
                    }
                }
            }
        }
    }
}
// Adds a linear positional bias to one warp's score tile.
//
// Every value at (row, col) is incremented by gAlibi * (col - row), where
//     row = row_idx_offset_ + (lane & 15) * 2 + mi * 32 + min_tile_m
//     col = col_idx_offset_ + (lane >> 4) * 2 + ni * 32 + min_tile_n + v * 8
// max_seqlen_k and max_seqlen_q are accepted but not read by this function.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_alibi(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                   const int max_seqlen_k, const int row_idx_offset_,
                                   const int max_seqlen_q, float gAlibi) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15) * 2;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < 2; ++sub_m) {
            const int row = first_row + tile_m * 32 + sub_m;
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int lane_vec = 0; lane_vec < 4; ++lane_vec) {
                        const int col = first_col + tile_n * 32 + sub_n + lane_vec * 8;
                        // Signed distance between key column and query row.
                        tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[lane_vec] += gAlibi * (col - row);
                    }
                }
            }
        }
    }
}
// Per-thread row reduction of one warp's tile with the binary functor `op`.
//
// For every row this lane owns (row tile m_idx, half min_tile_m) all values
// of that row held by the lane are folded with `op` into
// <dst>[m_idx * 2].f32[min_tile_m], where
//   * zero_init == true : <dst> is `summary`, seeded with -INFINITY;
//   * zero_init == false: <dst> is `summary_cur`, seeded with the matching
//     entry of `summary`; `summary` itself is only read.
// summary_cur may therefore stay nullptr only when zero_init is true.
// The OpType template parameter is not read by this function.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void thread_reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    // Where the running result is accumulated; fixed at compile time.
    DataType1 *acc = zero_init ? summary : summary_cur;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < 2; ++sub_m) {
            // Seed the accumulator for this row.
            if constexpr (zero_init) {
                acc[tile_m * 2].f32[sub_m] = -INFINITY;
            } else {
                acc[tile_m * 2].f32[sub_m] = summary[tile_m * 2].f32[sub_m];
            }
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    for (int lane_vec = 0; lane_vec < 4; ++lane_vec) { // mmac min_tile is 16*16, a warp is 64 thread
                        acc[tile_m * 2].f32[sub_m] = op(acc[tile_m * 2].f32[sub_m], tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[lane_vec]);
                    }
                }
            }
        }
    }
}
// Thread-local step of the row-wise sum reduction (no cross-lane traffic happens here).
//
// For every 32-row tile `m_idx`, accumulates the fragment values of `tensor` that belong to
// each of its two sub-rows.
//   zero_init == true : accumulation starts from 0 and is written to `summary`;
//                       `summary_cur` is not touched.
//   zero_init == false: accumulation starts from the value stored in `summary` and is written
//                       to `summary_cur`; `summary` is only read.
//
// On gfx936 / gfx938 / gfx946 both sub-rows are accumulated together with the packed
// __builtin_hcu_pk_add_f32 (the original comment names the instruction v_pk_add_f32) and
// `op` is not called; on other targets each sub-row is folded with `op`.
//
// `summary` / `summary_cur` carry one two-float entry per 32-row tile and are indexed by
// `m_idx`, which is the indexing quad_allreduce_ uses for the same arrays (entries
// [0, WARP_M / 32)). The former `m_idx * 2` index matched that only for WARP_M == 32 and
// stepped past the end of a (WARP_M / 32)-entry array for larger WARP_M.
// `OpType` is not used by this function.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    // Array that receives the accumulated value.
    DataType1 *dst = zero_init ? summary : summary_cur;
    #pragma unroll
    for (int m_idx = 0; m_idx < kTilesM; ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
        // Packed path: lane 0 of the pair is sub-row 0, lane 1 is sub-row 1.
        if constexpr (zero_init) {
            dst[m_idx].u64 = 0x0;
        } else {
            dst[m_idx].u64 = summary[m_idx].u64;
        }
        #pragma unroll
        for (int n_idx = 0; n_idx < kTilesN; ++n_idx) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
                    __float2 additem_pair = {
                        tensor[m_idx + n_idx * kTilesM][min_tile_n * 2].f32[vec_idx],
                        tensor[m_idx + n_idx * kTilesM][min_tile_n * 2 + 1].f32[vec_idx]};
                    dst[m_idx].u64 = __builtin_hcu_pk_add_f32(
                        dst[m_idx].u64,
                        additem_pair
                    );
                }
            }
        }
#else
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            if constexpr (zero_init) {
                dst[m_idx].f32[min_tile_m] = 0;
            } else {
                dst[m_idx].f32[min_tile_m] = summary[m_idx].f32[min_tile_m];
            }
            #pragma unroll
            for (int n_idx = 0; n_idx < kTilesN; ++n_idx) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
                        dst[m_idx].f32[min_tile_m] = op(
                            dst[m_idx].f32[min_tile_m],
                            tensor[m_idx + n_idx * kTilesM][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
                    }
                }
            }
        }
#endif
    }
}
// Cross-lane step of a row reduction: every one of the WARP_M / 32 entries of `src` is
// passed through Allreduce<64>::run with `op` and the result is stored in the matching
// entry of `dst`. `dst` and `src` may be the same array.
template<typename Operator, typename DataType, int WARP_M>
__device__ inline void quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
    constexpr int kRowTiles = WARP_M / 32;
    #pragma unroll
    for (int tile = 0; tile != kRowTiles; ++tile) {
        dst[tile] = Allreduce<64>::run(src[tile], op);
    }
}
// Full row reduction of `tensor`: a thread-local fold followed by the cross-lane
// quad_allreduce_.
//   OpType == 0 selects thread_reduce_sum, OpType == 1 selects thread_reduce_max;
//   any other OpType compiles to nothing.
//   zero_init == true : the result ends up in `summary` (`summary_cur` is not accessed).
//   zero_init == false: `summary` supplies the starting value and the result ends up in
//                       `summary_cur`.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    if constexpr (OpType == 0) { // sum
        // The zero_init == true instantiation of the thread-local fold never reads the fourth argument.
        thread_reduce_sum<zero_init, Operator, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op, summary_cur);
        DataType1 *reduced = zero_init ? summary : summary_cur;
        quad_allreduce_<Operator, DataType1, WARP_M>(reduced, reduced, op);
    } else if constexpr (OpType == 1) { // max
        thread_reduce_max<zero_init, Operator, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op, summary_cur);
        DataType1 *reduced = zero_init ? summary : summary_cur;
        quad_allreduce_<Operator, DataType1, WARP_M>(reduced, reduced, op);
    }
}
// Row-wise maximum of `tensor`, reduced with MaxOp<float>.
//   zero_init == true : `max` receives the maximum of this tensor; `max_cur` is not used
//                       (leave it at nullptr).
//   zero_init == false: `max` holds the previous maximum and is only read; the combined
//                       maximum is written to `max_cur`, which must not be nullptr.
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *max , DataType1 *max_cur=nullptr) {
    using MaxOperator = MaxOp<float>;
    MaxOperator max_op;
    // reduce_ ignores its last argument when zero_init is true, so one call serves both modes.
    reduce_<zero_init, MaxOperator, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op, max_cur);
}
// Row-wise sum of `tensor`, reduced with SumOp<float>.
//   zero_init == true : `sum` receives the sum of this tensor; `sum_cur` is not used.
//   zero_init == false: `sum` supplies the starting value and is only read; the result is
//                       written to `sum_cur`, which must not be nullptr.
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_sum(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){
    using SumOperator = SumOp<float>;
    SumOperator sum_op;
    // reduce_ ignores its last argument when zero_init is true, so one call serves both modes.
    reduce_<zero_init, SumOperator, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op, sum_cur);
}
// Apply the exp to all the elements, in place:
//     tensor[...] = exp2(tensor[...] * scale - max_scaled)
// where, per sub-row, max_scaled = max * scale when Scale_max is true and max * M_LOG2E
// otherwise; a row maximum of -INFINITY is replaced by 0 (see the comment in the loop).
//
// `max` carries one two-float entry per 32-row tile and is indexed by `mi`, the indexing
// quad_allreduce_ uses when it produces these values (entries [0, WARP_M / 32)). The former
// `mi * 2` index matched that only for WARP_M == 32 and read past the end of a
// (WARP_M / 32)-entry array for larger WARP_M.
template <bool Scale_max=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], const DataType1 *max, const float scale) {
    #pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const float row_max = max[mi].f32[min_tile_m];
            const float max_scaled = (row_max == -INFINITY) ? 0.f : (row_max * (Scale_max ? scale : float(M_LOG2E)));
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
            // Operands of the packed fma: both lanes carry the same scale / offset.
            __float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
            __float2 scale_pair = {scale, scale};
#endif
            #pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
                // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                // max * log_2(e)) This allows the compiler to use the ffma
                // instruction instead of fadd and fmul separately.
                // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    int mmac_id = min_tile_n * 2 + min_tile_m;
                    int qk_tile_id = mi + ni * (WARP_M / 32);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                    // Two packed fmas cover the four floats of the fragment, then exp2 per float.
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
                            tensor[qk_tile_id][mmac_id].u64[vec_idx],
                            scale_pair,
                            neg_max_scaled_pair
                        );
                    }
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
                    }
#else
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx] * scale - max_scaled);
                    }
#endif
                }
            }
        }
    }
}
// Online-softmax update for one block of attention scores.
//
// Is_first == true:
//   `scores_max` receives the row maxima of `scores`, `scores` is replaced in place by
//   exp2(scores * softmax_scale_log2 - max * softmax_scale_log2) and `scores_sum` receives
//   the row sums of the result. `acc_o` is not touched.
// Is_first == false:
//   `scores_max` / `scores_sum` hold the running values. The combined maximum is formed in
//   a local array; `scores_sum` and every `acc_o` fragment of the row are multiplied by
//   exp2((old_max - new_max) * softmax_scale_log2); `scores` is exponentiated against the
//   new maximum; its row sums are added to `scores_sum`; finally `scores_max` is overwritten
//   with the new maximum.
// Check_inf: when true, a new maximum of -INFINITY is treated as 0 while computing the
//   rescale factor, which avoids (-inf) - (-inf).
//
// All per-row arrays (`scores_max`, `scores_sum` and the local temporaries) hold WARP_M / 32
// two-float entries and are indexed by `mi` throughout. The rescale loop used to index them
// with `mi * 2`, which disagreed with the `[mi]` indexing at the end of this same function
// and overran the (WARP_M / 32)-entry local array whenever WARP_M > 32.
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1, int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N>
inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_M / 32)][4], DataType1 *scores_max, DataType1 *scores_sum,
                                         DataType0 acc_o[(K / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4], float softmax_scale_log2) {
    if constexpr (Is_first) {
        reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max);
        scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, softmax_scale_log2);
        reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum);
    } else {
        DataType1 scores_max_cur[(WARP_M / 32)];
        reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
            // If max is -inf, then all elements must have been -inf (possibly due to masking).
            // We don't want (-inf - (-inf)) since that would give NaN.
            // If we don't have float around M_LOG2E the multiplication is done in fp64.
            #pragma unroll
            for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                float scores_max_cur_reg = !Check_inf
                    ? scores_max_cur[mi].f32[min_tile_m]
                    : (scores_max_cur[mi].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi].f32[min_tile_m]);
                // Factor that brings values accumulated against the old maximum onto the new one.
                float scores_scale = __llvm_exp2_f32((scores_max[mi].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
                scores_sum[mi].f32[min_tile_m] *= scores_scale;
                __float2 scores_scale_pair = {scores_scale, scores_scale};
                #pragma unroll
                for (int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) {
                    #pragma unroll
                    for (int ni = 0; ni < (kBlockK / 32); ++ni) {
                        // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                        // max * log_2(e)) This allows the compiler to use the ffma
                        // instruction instead of fadd and fmul separately.
                        // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + mi + ni * (WARP_M / 32);
                            int mmac_id = min_tile_n * 2 + min_tile_m;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                            #pragma unroll
                            for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
                                acc_o[pv_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
                                    acc_o[pv_tile_id][mmac_id].u64[vec_idx],
                                    scores_scale_pair
                                );
                            }
#else
                            #pragma unroll
                            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                                acc_o[pv_tile_id][mmac_id].f32[vec_idx] *= scores_scale;
                            }
#endif
                        }
                    }
                }
            }
        }
        scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max_cur, softmax_scale_log2);
        DataType1 scores_sum_cur[(WARP_M / 32)];
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
            scores_sum_cur[mi].u64 = 0x0;
        }
        reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum_cur);
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
            scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
                scores_sum[mi].u64,
                scores_sum_cur[mi].u64
            );
#else // for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results
            scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
            scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
#endif
            // Publish the new maximum as the running maximum.
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__))
            inlineasm_fa_v_mov_b64(
                scores_max[mi].u64,
                scores_max_cur[mi].u64
            );
#else
            scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
            scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
#endif
        }
    }
}
// Down-converts every float of the fp32 fragments in `s_reg` to `Element` and stores it at
// the same tile / fragment / element position of `p_reg`.
// On gfx938 / gfx946 / gfx92a two floats are converted at once with DownCastPair; on other
// targets each float goes through DownCast<float, Element, false>.
// `ElementAccum` is not used by this function.
template <int WARP_M, int WARP_N, typename Element, typename ElementAccum>
inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (WARP_N / 32)][4], union_vec4_fp32 s_reg[(WARP_M / 32) * (WARP_N / 32)][4]) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    #pragma unroll
    for (int n_idx = 0; n_idx < kTilesN; ++n_idx) {
        #pragma unroll
        for (int m_idx = 0; m_idx < kTilesM; ++m_idx) {
            const int tile = n_idx * kTilesM + m_idx;
            #pragma unroll
            for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
                // The two fragments of this sub-row: min_tile_n == 0 and min_tile_n == 1.
                const int frag_n0 = 0 * 2 + min_tile_m;
                const int frag_n1 = 1 * 2 + min_tile_m;
                #pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                    p_reg[tile][frag_n0].f16x2[min_tile_k] = DownCastPair<float, Element>(
                        s_reg[tile][frag_n0].f32x2[min_tile_k]);
                    p_reg[tile][frag_n1].f16x2[min_tile_k] = DownCastPair<float, Element>(
                        s_reg[tile][frag_n1].f32x2[min_tile_k]);
#else
                    const int elem_even = min_tile_k * 2 + 0;
                    const int elem_odd = min_tile_k * 2 + 1;
                    p_reg[tile][frag_n0].f16[elem_even] = DownCast<float, Element, false>(
                        s_reg[tile][frag_n0].f32[elem_even]);
                    p_reg[tile][frag_n1].f16[elem_even] = DownCast<float, Element, false>(
                        s_reg[tile][frag_n1].f32[elem_even]);
                    p_reg[tile][frag_n0].f16[elem_odd] = DownCast<float, Element, false>(
                        s_reg[tile][frag_n0].f32[elem_odd]);
                    p_reg[tile][frag_n1].f16[elem_odd] = DownCast<float, Element, false>(
                        s_reg[tile][frag_n1].f32[elem_odd]);
#endif
                }
            }
        }
    }
}
// (end of file: the original source has no trailing newline)
#pragma once
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include "numeric_types.h"
#include "wait.h"
#include "intrinsic.h"
namespace flash {
// Emits `s_setprio <priority_level>` for the calling wave.
// `priority_level` is bound with the "n" (immediate) asm constraint, so it has to fold to a
// compile-time constant at each force-inlined call site.
// The instruction is fenced on both sides with __builtin_amdgcn_sched_barrier(0) so the
// compiler does not schedule other instructions across the priority change.
__forceinline__ __device__ void raise_priority(const int priority_level=2) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio %0":: "n"(priority_level));
__builtin_amdgcn_sched_barrier(0);
}
// Emits `s_setprio 0` for the calling wave, fenced on both sides with
// __builtin_amdgcn_sched_barrier(0) so the compiler does not schedule other instructions
// across the priority change.
__forceinline__ __device__ void lower_priority() {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
__builtin_amdgcn_sched_barrier(0);
}
// Integer division rounded up, computed as (a + b - 1) / b.
// The result is the ceiling of a / b for a >= 0 and b > 0; other sign combinations follow
// C++ truncating division and are not adjusted.
inline __device__ constexpr int ceil_div(int const& a, int const& b) {
    const int biased_numerator = a + b - 1;
    return biased_numerator / b;
}
// Plain C++ integer division a / b. Because C++ truncates toward zero, this equals the
// mathematical floor only when a and b have the same sign (or a is 0).
inline __device__ constexpr int floor_div(int const& a, int const& b) {
    const int quotient = a / b;
    return quotient;
}
// Generic matrix-multiply-accumulate wrapper: forwards (v1, v2, v3) to
// __builtin_hcu_mmac_f32_16x16x16_f16 and returns its result.
// On gfx916 / gfx926 the builtin is not used and an all-zero vector is returned instead.
// `AccumType` is not used by the primary template.
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
#if !defined(__gfx916__) && !defined(__gfx926__)
    return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#else
    return {0, 0, 0, 0};
#endif
}
// fp16 specialization: forwards to __builtin_hcu_mmac_f32_16x16x16_f16.
// On gfx916 / gfx926 the builtin is not used and an all-zero vector is returned instead.
template<>
inline __device__ vec4_fp32 mmac<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if !defined(__gfx916__) && !defined(__gfx926__)
    return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#else
    return {0, 0, 0, 0};
#endif
}
// bf16 specialization: forwards to __builtin_hcu_mmac_f32_16x16x16_bf16.
// On gfx916 / gfx926 the builtin is not used and an all-zero vector is returned instead.
template<>
inline __device__ vec4_fp32 mmac<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if !defined(__gfx916__) && !defined(__gfx926__)
    return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#else
    return {0, 0, 0, 0};
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary "larger of two" functor: returns x when x > y, otherwise y.
template<typename T>
struct MaxOp {
    __device__ inline T operator()(T const & x, T const & y) {
        if (x > y) {
            return x;
        }
        return y;
    }
};
// float specialization that delegates to max(); the original notes this is slightly faster
// than the comparison in the primary template.
template <>
struct MaxOp<float> {
    __device__ inline float operator()(float const &x, float const &y) {
        const float larger = max(x, y);
        return larger;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary addition functor: returns x + y.
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) {
        return x + y;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the value of `x` held by lane (lane_id ^ lane_mask), where lane_id is
// threadIdx.x & 63. The exchange goes through __builtin_amdgcn_ds_bpermute, which only
// handles `int`, so the bits of `x` are reinterpreted as an int on the way in and back to T
// on the way out (this assumes T is 32 bits wide).
template<typename T>
__forceinline__ __device__ T __shfl_xor_tmp(T x, const int lane_mask) {
    const int self_lane = threadIdx.x & 63;
    // The source lane is passed to the builtin multiplied by 4.
    const int peer_index = (self_lane ^ lane_mask) << 2;
    int fetched = __builtin_amdgcn_ds_bpermute(peer_index, *(int*)&x); // attention, __builtin only support int
    return *(T*)&fetched;
}
// Exchanges `x` between lanes through __builtin_amdgcn_ds_swizzle with pattern 0x401F and
// returns the received value. Allreduce uses it in place of __shfl_xor_tmp(x, 16).
// The bits of `x` are reinterpreted as an int for the builtin (this assumes T is 32 bits wide).
template<typename T>
__forceinline__ __device__ T __shfl_swap16(T x) {
    int swapped = __builtin_amdgcn_ds_swizzle(*(int*)&x, 0x401F);
    return *(T*)&swapped;
}
// Cross-lane reduction of a two-float value inside a 64-lane wave.
//
// run() combines the caller's `x` with the value held by lane ^ 32 and then with the value
// delivered by the 16-lane swap (__shfl_swap16, or __shfl_xor_tmp(., 16) on the non-packed
// sum path), independently for x.f32[0] and x.f32[1], and returns the combined pair.
//
//   Operator == SumOp<float>: the two steps are additions; on gfx936 / gfx938 / gfx946 both
//                             floats are added at once with __builtin_hcu_pk_add_f32. `op`
//                             itself is not called on this path.
//   any other Operator      : both steps are performed with `op`. This is exactly the code
//                             MaxOp<float> has always used; making it the fallback means an
//                             operator other than the two named ones no longer returns an
//                             uninitialised value.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 64);
    template<typename Operator>
    static __device__ inline union_vec2_fp32 run(union_vec2_fp32 x, Operator &op) {
        union_vec2_fp32 res;
        if constexpr (std::is_same<Operator, SumOp<float> >::value) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
            res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
            res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
            x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
            res.f32[0] = __shfl_swap16(x.f32[0]); // __shfl_xor_tmp(x.f32[0], 16);
            res.f32[1] = __shfl_swap16(x.f32[1]); // __shfl_xor_tmp(x.f32[1], 16);
            res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else
            x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
            x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
            res.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 16);
            res.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 16);
#endif
        } else {
            // Generic path (MaxOp<float> and any other binary float operator).
            x.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 32));
            x.f32[1] = op(x.f32[1], __shfl_xor_tmp(x.f32[1], 32));
            res.f32[0] = op(x.f32[0], __shfl_swap16(x.f32[0]));
            res.f32[1] = op(x.f32[1], __shfl_swap16(x.f32[1]));
        }
        return res;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds the four-dword resource value used by the buffer_load wrappers from a raw pointer.
//   dwords 0-1: the 64-bit address of `ptr`; when Do_CacheSwizzle is set and kHeadDim is
//               128, 192 or 64, a constant is added to dword 1. Per the original comments,
//               bit 62 of the 64-bit pair is the cache-swizzle flag and bits 48~61 are the stride.
//   dword 2   : 0x80000000
//   dword 3   : 0x00020000
// Other kHeadDim values leave dword 1 unmodified.
template<const int kHeadDim, typename T, bool Do_CacheSwizzle=true>
__device__ __forceinline__ vec4_uint prepare_for_buffer_load(T* ptr) {
    // Constant added to dword 1 for each supported head dimension (0 = nothing to add).
    // Reading bits 16..29 of each constant as the stride field gives 0x100, 0x180 and 0x80.
    constexpr unsigned int kSwizzleAndStride =
        (kHeadDim == 128) ? 0x41000000u :
        (kHeadDim == 192) ? 0x41800000u :
        (kHeadDim == 64)  ? 0x40800000u : 0u;
    vec4_uint descriptor;
    *(uint64_t*)&descriptor = reinterpret_cast<uint64_t>(ptr);
    if constexpr (Do_CacheSwizzle && kSwizzleAndStride != 0u) {
        descriptor[1] += kSwizzleAndStride;
    }
    descriptor[2] = 0x80000000;
    descriptor[3] = 0x00020000;
    return descriptor;
}
// for matrix_load
// Builds a four-dword value whose first two dwords are the 64-bit address of `ptr` and
// whose last two dwords are zero. `kHeadDim` is not used.
template<const int kHeadDim, typename T>
__device__ __forceinline__ vec4_uint prepare_for_matrix_load(T* ptr) {
    vec4_uint descriptor;
    *(uint64_t*)&descriptor = reinterpret_cast<uint64_t>(ptr);
    descriptor[2] = 0x0;
    descriptor[3] = 0x0;
    return descriptor;
}
// Resets the softmax state before the first block of scores is processed:
//   scores_max[i] = {-INFINITY, -INFINITY} and scores_sum[i] = {0, 0} for i < M_WARP_COUNT;
//   for every one of the K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT accumulator tiles, the
//   fragments [min_tile_n * 2 + min_tile_m] with min_tile_m < M_MMAC_COUNT and
//   min_tile_n < 2 are zeroed.
// The zeroing uses a 64-bit move on gfx936 (builtin) and gfx938 / gfx946 (inline asm) and
// four float stores on other targets.
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void attention_initialize(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4]
) {
    #pragma unroll
    for (int row_tile = 0; row_tile < M_WARP_COUNT; ++row_tile) {
        scores_max[row_tile].f32[0] = -INFINITY;
        scores_max[row_tile].f32[1] = -INFINITY;
        scores_sum[row_tile].f32[0] = 0;
        scores_sum[row_tile].f32[1] = 0;
    }
    uint64_t pk_zero = 0;
    #pragma unroll
    for (int acc_tile = 0; acc_tile < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++acc_tile) {
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                const int frag = min_tile_n * 2 + min_tile_m;
#if defined(__gfx936__)
                acc_o[acc_tile][frag].u64[0] = __builtin_hcu_mov_b64(pk_zero);
                acc_o[acc_tile][frag].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#elif defined(__gfx938__) || defined(__gfx946__)
                asm volatile("v_mov_b64 %0, 0x0"
                             : "=v"(acc_o[acc_tile][frag].u64[0])
                             :);
                asm volatile("v_mov_b64 %0, 0x0"
                             : "=v"(acc_o[acc_tile][frag].u64[1])
                             :);
#else
                acc_o[acc_tile][frag].f32[0] = 0;
                acc_o[acc_tile][frag].f32[1] = 0;
                acc_o[acc_tile][frag].f32[2] = 0;
                acc_o[acc_tile][frag].f32[3] = 0;
#endif
            }
        }
    }
}
} // namespace flash
#pragma once
#include <vector>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "numeric_types.h"
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#define USE_BUFFER_LOAD_DWORDX4
// #define USE_BUFFER_LOAD_DWORDX2
#endif
// Issues a single `buffer_load_dword` (glc slc, vector offset 0) against `buffer_resource`
// after an `s_nop 4`, and discards the loaded value: `container` is never read.
// NOTE(review): judging by the name this exists to warm the UTCL2 for the buffer;
// confirm the intent with the author.
// The asm is fenced with __builtin_amdgcn_sched_barrier(0) on both sides so the compiler
// does not schedule other instructions across it.
template<class DataType>
__forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resource) {
int container;
int offset = 0;
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_nop 4\n\t"
"buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc\n\t"
: "=v"(container)
: "v"(offset), "s"(buffer_resource)
);
__builtin_amdgcn_sched_barrier(0);
}
// Issues `buffer_load_dword ... lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = gvOffset_s << shfl_count;
    int voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Issues `buffer_load_dwordx2 ... lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx2_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = gvOffset_s << shfl_count;
    int voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Scalar load of one dword into `s_data` from `global_addr` plus `gvOffset_s`, where the
// offset is given in units of (1 << shfl_count) bytes.
//
// The byte offset used to be computed and then dropped: the instruction was emitted with
// only a destination and a base, so `gvOffset_s` had no effect. It is now passed as the
// scalar offset operand of s_load_dword.
// No wait is issued after the load.
// NOTE(review): the three-operand `s_load_dword sdst, sbase, soffset` form is the standard
// GCN SMEM encoding; confirm it is accepted by the target assembler.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_s_load_dword(DataType &s_data, const uint64_t global_addr, int gvOffset_s) {
    int offset_s = gvOffset_s << shfl_count;
    asm volatile("s_load_dword %0, %1, %2 \n"
                 : "=s"(s_data)
                 : "s"(global_addr), "s"(offset_s)
                 : "memory");
}
// Issues `buffer_load_dwordx4 ... lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    uint32_t soffset_bytes = gvOffset_s << shfl_count;
    uint32_t voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Same operation as the plain dwordx4-to-LDS load, with an `s_nop 3` emitted before m0 is
// written: m0 receives the address of `shared_addr` plus `lds_offset`, then
// `buffer_load_dwordx4 ... lds` is issued through `global_addr` with vector offset
// `offset_v` and scalar offset `offset_s`. All three offsets are given in units of
// (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void safe_inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &offset_s, const int &offset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = offset_s << shfl_count;
    int voffset_bytes = offset_v << shfl_count;
    asm volatile("s_nop 3\n\t"
                 "s_mov_b32 m0, %1\n\t"
                 "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Issues `buffer_load_dword ... glc slc lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_glc_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = gvOffset_s << shfl_count;
    int voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dword %0, %2, %3 ,offen offset:0 glc slc lds\n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Issues `buffer_load_dword ... glc lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l1_glc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = gvOffset_s << shfl_count;
    int voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dword %0, %2, %3 ,offen offset:0 glc lds\n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Issues `buffer_load_dword ... slc lds` through the buffer resource `global_addr`.
// m0 is first loaded with the address of `shared_addr` plus `lds_offset`; the vector and
// scalar offsets of the load are `gvOffset_v` and `gvOffset_s`. All three offsets are given
// in units of (1 << shfl_count) bytes and shifted to bytes here.
// No wait is issued after the load.
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l2_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
    int m0_value = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
    int soffset_bytes = gvOffset_s << shfl_count;
    int voffset_bytes = gvOffset_v << shfl_count;
    asm volatile("s_mov_b32 m0, %1 \n\t"
                 "buffer_load_dword %0, %2, %3 ,offen offset:0 slc lds\n"
                 :: "v"(voffset_bytes), "s"(m0_value), "s"(global_addr), "s"(soffset_bytes)
                 :);
}
// Buffer-load-to-LDS through __builtin_hcu_raw_buffer_load_lds instead of inline asm.
//   shared_addr / lds_offset : destination, `shared_addr` viewed as dst_type* and advanced
//                              by `lds_offset` elements.
//   gvOffset_v / gvOffset_s  : vector / scalar offsets in dst_type elements; converted to
//                              bytes with sizeof(dst_type).
//   dword_count              : load size is dword_count * 4 bytes.
//   auxilariy                : auxiliary bits forwarded to the builtin; per the original
//                              comment: bit 0 glc, bit 1 slc, bit 2 dlc, bit 3 cache swizzle.
// The builtin's immediate (instruction) offset is always 0.
template<typename src_type=half_t, typename dst_type=float, const int dword_count=1, const int auxilariy=0>
__forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
    constexpr int kElementBytes = sizeof(dst_type);
    dst_type *lds_dst = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
    __builtin_hcu_raw_buffer_load_lds(
        rsrc,
        lds_dst,
        dword_count * 4,
        gvOffset_v * kElementBytes,
        gvOffset_s * kElementBytes,
        0, /* immediate offset, instruction offset */
        auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
    );
}
// One-dword (4-byte) buffer-load-to-LDS through __builtin_hcu_raw_buffer_load_lds with the
// auxiliary value fixed to 11. With the bit layout given in the original comment
// (bit 0 glc, bit 1 slc, bit 2 dlc, bit 3 cache swizzle) that is glc | slc | cache swizzle.
//   shared_addr / lds_offset : destination, `shared_addr` viewed as dst_type* and advanced
//                              by `lds_offset` elements.
//   gvOffset_v / gvOffset_s  : vector / scalar offsets in dst_type elements; converted to
//                              bytes with sizeof(dst_type).
template<typename src_type=half_t, typename dst_type=float>
__forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
    constexpr int kElementBytes = sizeof(dst_type);
    dst_type *lds_dst = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
    __builtin_hcu_raw_buffer_load_lds(
        rsrc,
        lds_dst,
        4,
        gvOffset_v * kElementBytes,
        gvOffset_s * kElementBytes,
        0, /* immediate offset, instruction offset */
        11 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
    );
}
// Issues `buffer_store_dword` of `v_data` through the buffer resource `global_addr`.
// `v_offset` (vector offset), `s_offset` (scalar offset) and `s_constant` (instruction
// `offset:` field) are given in units of (1 << shfl_count) bytes and shifted to bytes here.
// `s_constant` is bound as an asm immediate, so it must fold to a compile-time constant at
// each force-inlined call site.
template<class DataType, const int shfl_count>
__forceinline__ __device__ void inline_buffer_store_dword(const DataType v_data, const int &v_offset, const vec4_uint global_addr, const int &s_offset, const int s_constant=0) {
    int voffset_bytes = v_offset << shfl_count;
    int soffset_bytes = s_offset << shfl_count;
    const int imm_offset_bytes = s_constant << shfl_count;
    asm volatile(
        "buffer_store_dword %0, %1, %2, %3, offen offset:%4 \n"
        :: "v"(v_data), "v"(voffset_bytes), "s"(global_addr), "s"(soffset_bytes), "B"(imm_offset_bytes)
        :);
}
// Issues `buffer_store_dwordx4` of `v_data` through the buffer resource `global_addr`.
// `v_offset` (vector offset), `s_offset` (scalar offset) and `s_constant` (instruction
// `offset:` field) are given in units of (1 << shfl_count) bytes and shifted to bytes here.
// `s_constant` is bound as an asm immediate, so it must fold to a compile-time constant at
// each force-inlined call site.
template<class DataType, const int shfl_count>
__forceinline__ __device__ void inline_buffer_store_dwordx4(const DataType v_data, const int &v_offset, const vec4_uint global_addr, const int &s_offset, const int s_constant=0) {
    int voffset_bytes = v_offset << shfl_count;
    int soffset_bytes = s_offset << shfl_count;
    const int imm_offset_bytes = s_constant << shfl_count;
    asm volatile(
        "buffer_store_dwordx4 %0, %1, %2, %3, offen offset:%4 \n"
        :: "v"(v_data), "v"(voffset_bytes), "s"(global_addr), "s"(soffset_bytes), "B"(imm_offset_bytes)
        :);
}
// Issues `buffer_store_dword ... glc slc` of `v_data` through the buffer resource
// `global_addr`. `v_offset`, `s_offset` and `s_constant` are given in units of
// (1 << shfl_count) bytes and shifted to bytes here; `s_constant` is bound as an asm
// immediate. On gfx916 / gfx926 no instruction is emitted and the call does nothing.
template<class DataType, const int shfl_count>
__forceinline__ __device__ void inline_buffer_store_dword_glc_slc(DataType v_data, int &v_offset, vec4_uint global_addr, int &s_offset, const int s_constant=0) {
    int voffset_bytes = v_offset << shfl_count;
    int soffset_bytes = s_offset << shfl_count;
    const int imm_offset_bytes = s_constant << shfl_count;
#if !defined(__gfx916__) && !defined(__gfx926__)
    asm volatile(
        "buffer_store_dword %0, %1, %2, %3, offen offset:%4 glc slc\n"
        :: "v"(v_data), "v"(voffset_bytes), "s"(global_addr), "s"(soffset_bytes), "B"(imm_offset_bytes)
        :);
#endif
}
// Issues `ds_read_u16` from the LDS byte address `lds_offset` into `reg_val`.
// No wait follows the read ("no_wait"); presumably the caller waits before it consumes
// `reg_val` (verify against callers).
template<typename VEC>
__forceinline__ __device__ void inline_ds_read_b16_no_wait_bytes(const int &lds_offset, VEC &reg_val) {
asm volatile(
"ds_read_u16 %0 ,%1\n"
: "=v"(reg_val)
: "v"(lds_offset)
:);
}
// Issues `ds_read_b32` into `reg_val` from the address of `shared_addr` plus
// `lds_offset` dwords (lds_offset * 4 bytes).
// No wait follows the read ("no_wait"); presumably the caller waits before it consumes
// `reg_val` (verify against callers).
template<typename VEC>
__forceinline__ __device__ void inline_ds_read_b32_no_wait(VEC *const shared_addr, const int &lds_offset, VEC &reg_val) {
    int read_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4;
    asm volatile(
        "ds_read_b32 %0, %1\n"
        : "=v"(reg_val)
        : "v"(read_addr)
        :);
}
// Issues `ds_read_b32` from the LDS byte address `lds_offset` into `reg_val`.
// No wait follows the read ("no_wait"); presumably the caller waits before it consumes
// `reg_val` (verify against callers).
template<typename VEC>
__forceinline__ __device__ void inline_ds_read_b32_no_wait_bytes(const int &lds_offset, VEC &reg_val) {
asm volatile(
"ds_read_b32 %0, %1\n"
: "=v"(reg_val)
: "v"(lds_offset)
:);
}
// Issues `ds_read2_b32` into the two-dword `reg_val` from the address of `shared_addr`
// plus `lds_offset` dwords, with offset0 fixed to 0 and offset1 taken from `offset1`.
// `offset1` is bound as an asm immediate, so it must fold to a compile-time constant at
// each force-inlined call site.
// No wait follows the read ("no_wait"); presumably the caller waits before it consumes
// `reg_val` (verify against callers).
template<typename VEC, typename dwordx2>
__forceinline__ __device__ void inline_ds_read2_b32_no_wait(VEC *shared_addr, const int &lds_offset, dwordx2& reg_val, const int offset1) {
    int read_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4;
    asm volatile(
        "ds_read2_b32 %0 ,%1 offset0:0 offset1:%2\n"
        : "=v"(reg_val)
        : "v"(read_addr), "B"(offset1)
        :);
}
// Issues `ds_read2_b32` into the two-dword `reg_val` from the LDS byte address
// `lds_offset`, with offset0 fixed to 0 and offset1 taken from `offset1`.
// `offset1` is bound as an asm immediate, so it must fold to a compile-time constant at
// each force-inlined call site.
// No wait follows the read ("no_wait"); presumably the caller waits before it consumes
// `reg_val` (verify against callers).
template<typename dwordx2>
__forceinline__ __device__ void inline_ds_read2_b32_no_wait_bytes(const int &lds_offset, dwordx2& reg_val, const int offset1) {
asm volatile(
"ds_read2_b32 %0, %1 offset0:0 offset1:%2\n"
: "=v"(reg_val)
: "v"(lds_offset), "B"(offset1)
:);
}
// Issues `ds_read2_b32` into the two-dword `reg_val` from the address of `shared_addr`
// plus `lds_offset` floats (lds_offset * 4 bytes), with both instruction offsets supplied
// by the caller. `offset0` and `offset1` are bound as asm immediates, so they must fold to
// compile-time constants at each force-inlined call site.
// No wait follows the read.
template<typename dwordx2>
__forceinline__ __device__ void inlineasm_fa_ds_read2_b32(float* shared_addr, const int &lds_offset, dwordx2& reg_val, const int offset0, const int offset1) {
    int read_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4;
    asm volatile(
        "ds_read2_b32 %0, %1 offset0:%2 offset1:%3\n"
        : "=v"(reg_val)
        : "v"(read_addr), "B"(offset0), "B"(offset1)
        :);
}
// Issues `ds_write2_b32`, storing reg_val[0] and reg_val[1] to LDS relative to the address
// of `shared_addr` plus `lds_offset` floats (lds_offset * 4 bytes), using the instruction
// offsets `offset0` and `offset1`. Both offsets are bound as asm immediates, so they must
// fold to compile-time constants at each force-inlined call site.
// No wait follows the write ("no_wait").
//
// The address is an INPUT of the instruction. It used to be listed as an asm output
// ("=v"), which tells the compiler the asm produces that register rather than consumes it,
// so the computed address was never guaranteed to be in the register the store used.
__forceinline__ __device__ void inline_ds_write2_b32_no_wait_bytes(float* shared_addr, int lds_offset, const __float2& reg_val, const int offset0, const int offset1) {
    int lds_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4;
    asm volatile(
        "ds_write2_b32 %0, %1, %2 offset0:%3 offset1:%4\n"
        :
        : "v"(lds_addr), "v"(reg_val[0]), "v"(reg_val[1]), "B"(offset0), "B"(offset1)
        :);
}
// Plain C++ read of element `lds_offset` of `shared_addr` into `reg_val`; unlike the
// inline-asm readers, the compiler generates the load here.
template<typename VEC>
__forceinline__ __device__ void inline_ds_read_b32_wait(VEC *const shared_addr, const int &lds_offset, VEC &reg_val) {
    reg_val = *(shared_addr + lds_offset);
}
// Issues `ds_read_b128` from the LDS byte address `lds_offset` into `data`.
// No wait follows the read.
//
// The address operand is bound to a vector register ("v"), like every other ds_read /
// ds_write wrapper in this file; it used to be bound with "s" (scalar register).
template<typename VEC>
__forceinline__ __device__ void inlineasm_ds_read_b128(int lds_offset, VEC& data) {
    asm volatile("ds_read_b128 %0, %1\n"
                 : "=v"(data)
                 : "v"(lds_offset)
                 :);
}
// Issues `ds_write_b128`, storing `data` at the LDS byte address `lds_offset`.
// No wait follows the write.
//
// The address operand is bound to a vector register ("v"), like every other ds_read /
// ds_write wrapper in this file; it used to be bound with "s" (scalar register).
template<typename VEC>
__forceinline__ __device__ void inlineasm_ds_write_b128(int lds_offset, VEC& data) {
    asm volatile("ds_write_b128 %0, %1\n"
                 :: "v"(lds_offset), "v"(data)
                 :);
}
// Sets element `idx` of `dst` to zero with an explicit `v_mov_b32 ..., 0x0`.
// The asm is not volatile, so the compiler may drop it if the result is unused.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr_init_zero(VEC &dst, const int idx) {
asm ("v_mov_b32 %0, 0x0"
: "=v"(dst[idx])
:);
}
// Zeroes `dst` as a whole: with one `v_mov_b64 ..., 0x0` on gfx936 / gfx938 / gfx946,
// with a plain assignment of 0 on other targets.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0"
: "=v"(dst)
:);
#else
dst = 0x0;
#endif
}
// Zeroes a four-float fragment `dst`: two `v_mov_b64` (dst.u64[0], dst.u64[1]) on
// gfx936 / gfx938 / gfx946, four `v_mov_b32` (dst.f32[0..3]) on other targets.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
: "=v"(dst.u64[0]), "=v"(dst.u64[1])
:);
#else
asm ("v_mov_b32 %0, 0x0\n\t"
"v_mov_b32 %1, 0x0\n\t"
"v_mov_b32 %2, 0x0\n\t"
"v_mov_b32 %3, 0x0\n\t"
: "=v"(dst.f32[0]), "=v"(dst.f32[1]), "=v"(dst.f32[2]), "=v"(dst.f32[3])
:);
#endif
}
// Zeroes all 16 fragments of `s_reg[4][4]` (both u64 halves of each fragment).
// gfx936 / gfx938 / gfx946: a single asm statement with 32 `v_mov_b64` instructions, one
// per u64 half; operand %k is listed in row-major order (s_reg[0][0].u64[0],
// s_reg[0][0].u64[1], s_reg[0][1].u64[0], ...).
// Other targets: the same stores written as a plain nested loop.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t"
"v_mov_b64 %3, 0x0\n\t"
"v_mov_b64 %4, 0x0\n\t"
"v_mov_b64 %5, 0x0\n\t"
"v_mov_b64 %6, 0x0\n\t"
"v_mov_b64 %7, 0x0\n\t"
"v_mov_b64 %8, 0x0\n\t"
"v_mov_b64 %9, 0x0\n\t"
"v_mov_b64 %10, 0x0\n\t"
"v_mov_b64 %11, 0x0\n\t"
"v_mov_b64 %12, 0x0\n\t"
"v_mov_b64 %13, 0x0\n\t"
"v_mov_b64 %14, 0x0\n\t"
"v_mov_b64 %15, 0x0\n\t"
"v_mov_b64 %16, 0x0\n\t"
"v_mov_b64 %17, 0x0\n\t"
"v_mov_b64 %18, 0x0\n\t"
"v_mov_b64 %19, 0x0\n\t"
"v_mov_b64 %20, 0x0\n\t"
"v_mov_b64 %21, 0x0\n\t"
"v_mov_b64 %22, 0x0\n\t"
"v_mov_b64 %23, 0x0\n\t"
"v_mov_b64 %24, 0x0\n\t"
"v_mov_b64 %25, 0x0\n\t"
"v_mov_b64 %26, 0x0\n\t"
"v_mov_b64 %27, 0x0\n\t"
"v_mov_b64 %28, 0x0\n\t"
"v_mov_b64 %29, 0x0\n\t"
"v_mov_b64 %30, 0x0\n\t"
"v_mov_b64 %31, 0x0\n"
: "=v"(s_reg[0][0].u64[0]), "=v"(s_reg[0][0].u64[1]), "=v"(s_reg[0][1].u64[0]), "=v"(s_reg[0][1].u64[1]), "=v"(s_reg[0][2].u64[0]), "=v"(s_reg[0][2].u64[1]), "=v"(s_reg[0][3].u64[0]), "=v"(s_reg[0][3].u64[1]), "=v"(s_reg[1][0].u64[0]), "=v"(s_reg[1][0].u64[1]), "=v"(s_reg[1][1].u64[0]), "=v"(s_reg[1][1].u64[1]), "=v"(s_reg[1][2].u64[0]), "=v"(s_reg[1][2].u64[1]), "=v"(s_reg[1][3].u64[0]), "=v"(s_reg[1][3].u64[1]), "=v"(s_reg[2][0].u64[0]), "=v"(s_reg[2][0].u64[1]), "=v"(s_reg[2][1].u64[0]), "=v"(s_reg[2][1].u64[1]), "=v"(s_reg[2][2].u64[0]), "=v"(s_reg[2][2].u64[1]), "=v"(s_reg[2][3].u64[0]), "=v"(s_reg[2][3].u64[1]), "=v"(s_reg[3][0].u64[0]), "=v"(s_reg[3][0].u64[1]), "=v"(s_reg[3][1].u64[0]), "=v"(s_reg[3][1].u64[1]), "=v"(s_reg[3][2].u64[0]), "=v"(s_reg[3][2].u64[1]), "=v"(s_reg[3][3].u64[0]), "=v"(s_reg[3][3].u64[1])
:);
#else
uint64_t pk_zero = 0x0;
#pragma unroll
for (int i = 0; i < 4; ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
s_reg[i][j].u64[0] = pk_zero;
s_reg[i][j].u64[1] = pk_zero;
}
}
#endif
}
// Zero-fills columns 0 and 2 (only) of every row of a 4x4 array of VEC values;
// columns 1 and 3 are left untouched.
// Each cleared entry is written through .u64[0] and .u64[1].
//   * gfx936 / gfx938 / gfx946: 16 v_mov_b64 instructions in one asm statement.
//   * any other target: unrolled loop stepping the column index by 2.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t"
"v_mov_b64 %3, 0x0\n\t"
"v_mov_b64 %4, 0x0\n\t"
"v_mov_b64 %5, 0x0\n\t"
"v_mov_b64 %6, 0x0\n\t"
"v_mov_b64 %7, 0x0\n\t"
"v_mov_b64 %8, 0x0\n\t"
"v_mov_b64 %9, 0x0\n\t"
"v_mov_b64 %10, 0x0\n\t"
"v_mov_b64 %11, 0x0\n\t"
"v_mov_b64 %12, 0x0\n\t"
"v_mov_b64 %13, 0x0\n\t"
"v_mov_b64 %14, 0x0\n\t"
"v_mov_b64 %15, 0x0\n\t"
: "=v"(s_reg[0][0].u64[0]), "=v"(s_reg[0][0].u64[1]), "=v"(s_reg[0][2].u64[0]), "=v"(s_reg[0][2].u64[1]), "=v"(s_reg[1][0].u64[0]), "=v"(s_reg[1][0].u64[1]), "=v"(s_reg[1][2].u64[0]), "=v"(s_reg[1][2].u64[1]), "=v"(s_reg[2][0].u64[0]), "=v"(s_reg[2][0].u64[1]), "=v"(s_reg[2][2].u64[0]), "=v"(s_reg[2][2].u64[1]), "=v"(s_reg[3][0].u64[0]), "=v"(s_reg[3][0].u64[1]), "=v"(s_reg[3][2].u64[0]), "=v"(s_reg[3][2].u64[1])
:);
#else
// Portable path: j takes the values 0 and 2 only.
uint64_t pk_zero = 0x0;
#pragma unroll
for (int i = 0; i < 4; ++i) {
#pragma unroll
for (int j = 0; j < 4; j += 2) {
s_reg[i][j].u64[0] = pk_zero;
s_reg[i][j].u64[1] = pk_zero;
}
}
#endif
}
// Zero-fills all four entries of a 1x4 array of VEC values.
// Each entry is cleared through its two 64-bit members, .u64[0] and .u64[1].
//   * gfx936 / gfx938 / gfx946: 8 v_mov_b64 instructions in one asm statement.
//   * any other target: unrolled loop of plain assignments. This path was
//     previously missing, which left the tile uninitialized on those targets;
//     it mirrors the fallback of inline_vgpr4_init_zero_4x4x4.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
    asm ("v_mov_b64 %0, 0x0\n\t"
         "v_mov_b64 %1, 0x0\n\t"
         "v_mov_b64 %2, 0x0\n\t"
         "v_mov_b64 %3, 0x0\n\t"
         "v_mov_b64 %4, 0x0\n\t"
         "v_mov_b64 %5, 0x0\n\t"
         "v_mov_b64 %6, 0x0\n\t"
         "v_mov_b64 %7, 0x0\n\t"
         : "=v"(s_reg[0][0].u64[0]), "=v"(s_reg[0][0].u64[1]),
           "=v"(s_reg[0][1].u64[0]), "=v"(s_reg[0][1].u64[1]),
           "=v"(s_reg[0][2].u64[0]), "=v"(s_reg[0][2].u64[1]),
           "=v"(s_reg[0][3].u64[0]), "=v"(s_reg[0][3].u64[1])
         :);
#else
    uint64_t pk_zero = 0x0;
    #pragma unroll
    for (int j = 0; j < 4; ++j) {
        s_reg[0][j].u64[0] = pk_zero;
        s_reg[0][j].u64[1] = pk_zero;
    }
#endif
}
// Zero-fills columns 0 and 2 (only) of a 1x4 array of VEC values; columns 1
// and 3 are left untouched. Each cleared entry is written through .u64[0]
// and .u64[1].
//   * gfx936 / gfx938 / gfx946: 4 v_mov_b64 instructions in one asm statement.
//   * any other target: unrolled loop of plain assignments. This path was
//     previously missing, which left the entries uninitialized on those
//     targets; it mirrors the fallback of inline_vgpr4_init_zero_4x2x4.
template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_1x2x4(VEC s_reg[1][4]) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
    asm ("v_mov_b64 %0, 0x0\n\t"
         "v_mov_b64 %1, 0x0\n\t"
         "v_mov_b64 %2, 0x0\n\t"
         "v_mov_b64 %3, 0x0\n\t"
         : "=v"(s_reg[0][0].u64[0]), "=v"(s_reg[0][0].u64[1]),
           "=v"(s_reg[0][2].u64[0]), "=v"(s_reg[0][2].u64[1])
         :);
#else
    uint64_t pk_zero = 0x0;
    #pragma unroll
    for (int j = 0; j < 4; j += 2) {
        s_reg[0][j].u64[0] = pk_zero;
        s_reg[0][j].u64[1] = pk_zero;
    }
#endif
}
// Simplified f32 -> bf16 conversion that skips the NaN branch.
// Three compile-time variants:
//   * gfx938: hardware conversion through __builtin_hcu_cvt_bf16_f32; the
//     16-bit result is stored into `ret` through an unsigned short alias.
//   * `#elif 0`: a disabled inline-asm version, kept for reference only.
//   * otherwise: integer bias trick (add 0x7fff plus bit 16 of the input,
//     then keep the upper 16 bits).
// NaN inputs are NOT handled specially; callers are expected to pass values
// that cannot be NaN (the original note says the input comes from softmax).
inline __HOST_DEVICE__ bhalf_t inlineasm_float2bfloat16_nonan(const float f) {
bhalf_t ret;
#if defined(__gfx938__)
// ret.data = __builtin_hcu_cvt_bf16_f32(f, /*clamp*/false, /*dst_sel*/false);
*(unsigned short*)&ret = __builtin_hcu_cvt_bf16_f32(f, /*clamp*/false, /*dst_sel*/false);
// #elif __HIP_DEVICE_COMPILE__
// The inline-asm variant below may cause spills to scratch memory, so it is disabled.
#elif 0
unsigned int help = 0x7fff; // this line can be optimized, so as to use v_add3_u32
unsigned int tmp;
asm volatile(
"v_lshrrev_b32 %0, 16, %1\n\t"
"v_and_b32 %0, 0x1, %0\n\t"
: "=v"(tmp)
: "v"(f)
:);
asm volatile(
"v_add3_u32 %0, %2, %3, %4\n"
"v_lshrrev_b32 %1, 16, %0\n"
: "=v"(tmp), "=v"(ret.data)
: "v"(tmp), "s"(help), "v"(f)
:);
#else
// inf: 0x7f80_0000 + 0x0000_7fff + (0x7f80 & 1) = 0x7f80_7fff; >> 16 gives 0x7f80, still inf.
// NaN: not processed (input is assumed to come from softmax: > 0 and never NaN).
// Other values: original author's caveat - "not totally rounding to nearest even,
// no case of mantissa 1000" (NOTE(review): confirm what case is meant).
union {
float fp32;
unsigned int u32;
} u = {f};
// Bias by 0x7fff plus the lowest surviving bit, then truncate to the top 16 bits.
u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
*(unsigned short*)&ret = (u.u32 >> 16);
#endif
return ret;
}
// Same conversion as inlineasm_float2bfloat16_nonan, but returns the raw
// 16-bit pattern instead of the bhalf_t wrapper.
// (Author's note: no improvement was observed when writing to global memory.)
inline __HOST_DEVICE__ unsigned short inlineasm_float2bfloat16_ushort_nonan(const float f) {
    bhalf_t converted = inlineasm_float2bfloat16_nonan(f);
    unsigned short* raw_view = (unsigned short*)&converted;
    return *raw_view;
}
// d = a * b + c
// Emits a single v_pk_fma_f32 on the packed pair operands a, b, c and
// returns the result. The trailing "; inlineasm_fa_v_pk_fma_f32" is an
// assembler comment that tags the instruction in the generated listing.
// NOTE(review): `a` is taken by non-const reference although it is only used
// as an asm input; there is no target guard around the instruction.
inline __device__ __float2 inlineasm_fa_v_pk_fma_f32(__float2 &a, const __float2& b, const __float2& c) {
__float2 d;
asm volatile("v_pk_fma_f32 %0, %1, %2, %3 ; inlineasm_fa_v_pk_fma_f32"
: "=v"(d)
: "v"(a), "v"(b), "v"(c)
:);
return d;
}
// c = a
// Copies the packed pair `a` into `c` with a single v_mov_b64 instruction.
// The trailing "; inlineasm_fa_v_mov_b64" is an assembler comment tag.
inline __device__ void inlineasm_fa_v_mov_b64(__float2 &c, const __float2 &a) {
asm volatile("v_mov_b64 %0, %1 ; inlineasm_fa_v_mov_b64"
: "=v"(c)
: "v"(a)
:);
}
// Declares a device function whose symbol is bound, via the asm label, to the
// LLVM intrinsic name "llvm.fma.v2f32" (two-lane fused multiply-add).
extern __device__ __attribute__((const)) __float2 __llvm_v_pk_fma_f32(__float2, __float2, __float2) __asm("llvm.fma.v2f32");
// dst = src * scale
// Multiplies the packed pairs `src` and `scale` with a single v_pk_mul_f32
// and writes the result to `dst`. The trailing "; inlineasm_fa_v_pk_mul_f32"
// is an assembler comment tag.
inline __device__ void inlineasm_fa_v_pk_mul_f32(__float2 &dst, const __float2 &src, const __float2 &scale) {
asm volatile("v_pk_mul_f32 %0, %1, %2 ; inlineasm_fa_v_pk_mul_f32"
: "=v"(dst)
: "v"(src), "v"(scale)
:);
}
// c = a + b
// Adds the two lanes of `a` and `b` and stores the sums in `c`.
//   * gfx936 / gfx938 / gfx946: one v_pk_add_f32 instruction.
//   * any other target: two scalar additions on lanes [0] and [1].
inline __device__ void inline_v_pk_add_f32(__float2 &c, const __float2 &a, const __float2& b) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm volatile("v_pk_add_f32 %0, %1, %2 ; inline_v_pk_add_f32"
: "=v"(c)
: "v"(a), "v"(b)
:);
#else
c[0] = a[0] + b[0];
c[1] = a[1] + b[1];
#endif
}
/*
The stock exp2f treats very small results specially: for an input x below -126
it computes 2^(x + 64) * 2^{-64}.
For deep-learning workloads, values around 2^-126 are not that important, so it
is enough to let v_exp_f32 compute the result directly.
*/
// Device functions bound by asm label to LLVM intrinsic names
// (exp2, log2, scalar fma, and an HCU 64-bit move).
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");
extern __device__ __attribute__((const)) float __llvm_log2_f32(float) __asm("llvm.log2.f32");
extern __device__ __attribute__((const)) float __llvm_fma_f32(float, float, float) __asm("llvm.fma.f32");
extern __device__ __attribute__((const)) int64_t __builtin_hcu_mov_b64(int64_t) __asm("llvm.hcu.mov64");
/* The inline asm works when tested on its own, but produces wrong results when
   used inside flash attention; the cause is unknown. */
// Returns 2^x by issuing a single v_exp_f32 instruction on x.
inline __device__ float inlineasm_fa_v_exp2_f32(const float x) {
// return exp2f(x);
float y;
asm volatile(
// "s_waitcnt lgkmcnt(0)\n\t"
"v_exp_f32 %0, %1"
: "=v"(y)
: "v"(x)
:);
return y;
}
#if !defined(__NVCC__)
// Bit-layout constants for the floating-point formats used by the conversion
// helpers below. For each format: exponent width, mantissa width, total width,
// and exponent bias computed as 2^(exp_bits - 1) - 1.
// fp8_e5m2
constexpr int32_t e5m2_exp_bits = 5;
constexpr int32_t e5m2_mant_bits = 2;
constexpr int32_t e5m2_bits = 8;
constexpr int32_t e5m2_bias = (1 << (e5m2_exp_bits - 1)) - 1;
// fp8_e4m3
constexpr int32_t e4m3_exp_bits = 4;
constexpr int32_t e4m3_mant_bits = 3;
constexpr int32_t e4m3_bits = 8;
constexpr int32_t e4m3_bias = (1 << (e4m3_exp_bits - 1)) - 1;
// fp16
constexpr int32_t fp16_exp_bits = 5;
constexpr int32_t fp16_mant_bits = 10;
constexpr int32_t fp16_bits = 16;
constexpr int32_t fp16_bias = (1 << (fp16_exp_bits - 1)) - 1;
// fp32
constexpr int32_t fp32_exp_bits = 8;
constexpr int32_t fp32_mant_bits = 23;
constexpr int32_t fp32_bits = 32;
constexpr int32_t fp32_bias = (1 << (fp32_exp_bits - 1)) - 1;
// Converts an fp32 value to an fp8 E4M3 bit pattern (returned as uint8_t),
// rounding to nearest-even and saturating.
//   * NaN            -> 0x7f (sign dropped; matches NVIDIA's
//                       cvt.rn.satfinite.e4m3x2.f32 per the original note)
//   * inf or |x|>448 -> sign | 0x7e (largest finite E4M3 magnitude)
//   * |x| <= 2^-10   -> signed zero, except values strictly inside the
//                       2^-10 binade, which round up to the smallest subnormal
//   * otherwise      -> rounded normal or subnormal encoding
// Fixes relative to the previous version:
//   * The subnormal path shifted the significand right before rounding and
//     discarded the shifted-out bits, so inputs just above a rounding tie were
//     treated as exact ties (e.g. bits 0x3c100001 gave 0x04 instead of 0x05).
//     The shifted-out bits are now kept as a sticky flag.
//   * The sign shift used e5m2_bits; it now uses e4m3_bits (same value, 8).
__host__ __device__
static uint8_t __float2e4m3(const float src) {
    // Reinterpret the float as a 32-bit integer to access its bit fields.
    const uint32_t __src = *(const unsigned int*)&src;
    // Sign bit moved down to bit 7 of the fp8 byte.
    const uint8_t sign_bits = (__src & 0x80000000) >> (fp32_bits - e4m3_bits);
    // Exponent field, still in its fp32 bit position.
    uint32_t exp_bits = __src & 0x7f800000;
    // 23-bit mantissa field.
    uint32_t mant_bits = __src & 0x007fffff;
    // Magnitude (everything but the sign).
    const uint32_t data_scale = __src & 0x7fffffff;

    uint8_t result = 0x0;
    if (exp_bits == 0x7f800000 and mant_bits > 0x0) {
        /* NaN */
        // result = sign_bits | 0x7f; // output qNAN
        result = 0x7f; // for NV's __nv_cvt_float_to_fp8:cvt.rn.satfinite.e4m3x2.f32, output are all 0x7f
    } else if ((exp_bits == 0x7f800000 and mant_bits == 0x0) or (data_scale > 0x43e00000)) {
        /* inf, or magnitude above the E4M3 maximum (0x43e00000 == 448.0f) */
        result = sign_bits | 0x7e; // output MAX
    } else if (exp_bits <= 0x3a800000) {
        /* at or below half of the smallest subnormal (0x3a800000 == 2^-10) */
        result = (exp_bits == 0x3a800000 and mant_bits > 0x0) ? sign_bits | 0x1 : sign_bits;
    } else {
        /* representable as an E4M3 normal or subnormal */
        bool sticky = false; // true when non-zero bits were shifted out below
        if (exp_bits < 0x3c800000) {
            // Below the smallest E4M3 normal (2^-6): build the subnormal
            // significand by shifting in the implicit leading one.
            // exp_bits is one of 2^-9, 2^-8, 2^-7 here, so shift is 3, 2 or 1.
            const uint32_t shift = (0x3c800000 - exp_bits) >> fp32_mant_bits;
            const uint32_t significand = 0x800000 | mant_bits;
            sticky = (significand & ((1u << shift) - 1u)) != 0;
            mant_bits = significand >> shift;
            exp_bits = 0x0;
        } else {
            // Re-bias the exponent and place it above the 3 mantissa bits.
            exp_bits = ((exp_bits >> fp32_mant_bits) - (fp32_bias - e4m3_bias)) << e4m3_mant_bits;
        }
        // The low 20 mantissa bits do not fit into E4M3.
        const uint32_t discard_bits = mant_bits & 0xfffff;
        // Round to nearest; on an exact tie round to even. A set sticky flag
        // means the value is strictly above the tie, so it rounds up.
        const bool carry_a_bit = discard_bits > 0x80000
            or (discard_bits == 0x80000 and (sticky or (mant_bits & 0x100000)));
        mant_bits = (mant_bits & 0x700000) >> (fp32_mant_bits - e4m3_mant_bits);
        mant_bits = carry_a_bit ? mant_bits + 1 : mant_bits;
        result = sign_bits + exp_bits + mant_bits; // + rather than |, since mant may carry a bit to exp
    }
    return result;
}
// Decodes an fp8 E4M3 bit pattern into fp32.
//   * exponent field 0: zero or subnormal, built from the three mantissa bits
//     weighted 2^-7, 2^-8 and 2^-9, with the sign applied afterwards
//   * exponent 0xf with mantissa 0x7: NaN, returned as bit pattern 0x7fffffff
//   * everything else: normal number, assembled field by field
__host__ __device__
static float __e4m32float(const uint8_t src) {
    // Split the byte into its three fields.
    const uint32_t sign_field = src & 0x80;
    const uint32_t exp_field = (src & 0x78) >> e4m3_mant_bits;
    const uint32_t mant_field = src & 0x7;

    if (exp_field == 0x0) {
        // Zero or subnormal.
        float magnitude = 0.0078125f * ((mant_field & 0x4) >> 2) + 0.00390625f * ((mant_field & 0x2) >> 1) + 0.001953125f * (mant_field & 0x1);
        return sign_field > 0 ? -magnitude : magnitude;
    }

    uint32_t fp32_pattern;
    if (exp_field == 0xf and mant_field == 0x7) {
        // Input NaN.
        fp32_pattern = 0x7fffffff;
    } else {
        // Input normal: move sign to bit 31, re-bias the exponent, widen the mantissa.
        fp32_pattern = (sign_field << (fp32_bits - e4m3_bits)) + ((exp_field - e4m3_bias + fp32_bias) << fp32_mant_bits) + (mant_field << (fp32_mant_bits - e4m3_mant_bits));
    }
    return *(float*)&fp32_pattern;
}
#endif // end of #if !defined(__NVCC__)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Data-type conversion wrappers
// DownCast
// fp32 -> fp16
template<class FromType, class ToType, bool Is_short = false, typename std::enable_if<std::is_same<FromType, float>::value && std::is_same<ToType,half_t>::value, int>::type = 0>
__host__ __device__ ToType DownCast(const FromType &from_var) {
    ToType narrowed = __float2half(from_var);
    return narrowed;
}
// fp32 -> bf16, returned as the raw 16-bit storage type (selected with Is_short = true).
// gfx928 / gfx936 use the NaN-free fast conversion; other targets use __float2bfloat16.
template<class FromType, class ToType, bool Is_short = false, typename std::enable_if<std::is_same<FromType, float>::value && Is_short && std::is_same<ToType, BFloat16>::value, int>::type = 0>
__host__ __device__ unsigned short DownCast(const FromType &from_var) {
#if !(defined(__gfx928__) || defined(__gfx936__))
    bhalf_t rounded = __float2bfloat16(from_var);
    unsigned short* raw_view = (unsigned short*)&rounded;
    return *raw_view;
#else
    return inlineasm_float2bfloat16_ushort_nonan(from_var);
#endif
}
// fp32 -> bf16, returned as the BFloat16 struct itself (Is_short = false).
// The NaN-free fast conversion is currently hard-wired on; the
// __float2bfloat16 alternative is kept behind the disabled branch.
template<class FromType, class ToType, bool Is_short = false, typename std::enable_if<std::is_same<FromType, float>::value && !Is_short && std::is_same<ToType, BFloat16>::value, int>::type = 0>
__host__ __device__ BFloat16 DownCast(const float &from_var) {
#if 1
    BFloat16 converted = inlineasm_float2bfloat16_nonan(from_var);
    return converted;
#else
    return __float2bfloat16(from_var);
#endif
}
// fp32 -> fp8 (E4M3), returned as the raw uint8_t storage type (Is_uint8 = true).
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<std::is_same<FromType, float>::value && Is_uint8 && std::is_same<ToType, Float8_e4m3_t>::value, int>::type = 0>
__host__ __device__ uint8_t DownCast(const float &from_var) {
    uint8_t encoded = __float2e4m3(from_var);
    return encoded;
}
// fp32 -> fp8 (E4M3), returned as the Float8_e4m3_t struct itself (Is_uint8 = false).
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<std::is_same<FromType, float>::value && !Is_uint8 && std::is_same<ToType, Float8_e4m3_t>::value, int>::type = 0>
__host__ __device__ Float8_e4m3_t DownCast(const float &from_var) {
    Float8_e4m3_t wrapped(__float2e4m3(from_var));
    return wrapped;
}
// fp16 -> fp8 (E4M3), returned as the raw uint8_t storage type (Is_uint8 = true).
// The value is widened to fp32 first and then encoded.
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<std::is_same<FromType,half_t>::value && Is_uint8 && std::is_same<ToType, Float8_e4m3_t>::value, int>::type = 0>
__host__ __device__ uint8_t DownCast(const half_t &from_var) {
    return __float2e4m3(__half2float(from_var));
}
// fp16 -> fp8 (E4M3), returned as the Float8_e4m3_t struct itself (Is_uint8 = false).
// The value is widened to fp32 first and then encoded.
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<std::is_same<FromType,half_t>::value && !Is_uint8 && std::is_same<ToType, Float8_e4m3_t>::value, int>::type = 0>
__host__ __device__ Float8_e4m3_t DownCast(const half_t &from_var) {
    const float widened = __half2float(from_var);
    return Float8_e4m3_t(__float2e4m3(widened));
}
// fp32 -> fp32: identity overload, so generic callers can use DownCast even
// when the target type is already float.
template<class FromType, class ToType, bool Is_short = false, typename std::enable_if<std::is_same<FromType, float>::value && std::is_same<ToType, float>::value, int>::type = 0>
__host__ __device__ ToType DownCast(const FromType &from_var) {
    ToType unchanged = from_var;
    return unchanged;
}
//UpCast
// fp16 -> fp32
template<class FromType=half_t, class ToType=float, bool Is_short = false, typename std::enable_if<std::is_same<FromType,half_t>::value && std::is_same<ToType, float>::value, int>::type = 0>
__host__ __device__ float UpCast(const half_t &from_var) {
    float widened = __half2float(from_var);
    return widened;
}
// Raw bf16 storage (unsigned short) -> fp32 (selected with Is_short = true).
// The 16 bits are copied into a bhalf_t and then widened.
template<class FromType, class ToType, bool Is_short = false, typename std::enable_if<Is_short && std::is_same<FromType, BFloat16>::value && std::is_same<ToType, float>::value, int>::type = 0>
__host__ __device__ float UpCast(const unsigned short &from_var) {
    bhalf_t holder;
    unsigned short* raw_view = (unsigned short*)&holder;
    *raw_view = from_var;
    return __bfloat162float(holder);
}
// bf16 struct -> fp32 (Is_short = false)
template<class FromType=bhalf_t, class ToType=float, bool Is_short = false,typename std::enable_if<!Is_short && std::is_same<FromType, BFloat16>::value && std::is_same<ToType,float>::value, int>::type = 0>
__host__ __device__ float UpCast(const BFloat16 &from_var) {
    float widened = __bfloat162float(from_var);
    return widened;
}
// Raw fp8 (E4M3) storage byte -> fp32 (selected with Is_uint8 = true)
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<Is_uint8 && std::is_same<FromType, Float8_e4m3_t>::value && std::is_same<ToType, float>::value, int>::type = 0>
__host__ __device__ float UpCast(const uint8_t &from_var) {
    float decoded = __e4m32float(from_var);
    return decoded;
}
// fp8 (E4M3) struct -> fp32 (Is_uint8 = false); decodes the struct's .data byte.
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<!Is_uint8 && std::is_same<FromType, Float8_e4m3_t>::value && std::is_same<ToType, float>::value, int>::type = 0>
__host__ __device__ float UpCast(const Float8_e4m3_t &from_var) {
    float decoded = __e4m32float(from_var.data);
    return decoded;
}
// Raw fp8 (E4M3) storage byte -> fp16 (selected with Is_uint8 = true).
// Decodes to fp32 first, then narrows to half.
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<Is_uint8 && std::is_same<FromType, Float8_e4m3_t>::value && std::is_same<ToType,half_t>::value, int>::type = 0>
__host__ __device__ half_t UpCast(const uint8_t &from_var) {
    return __float2half(__e4m32float(from_var));
}
// fp8 (E4M3) struct -> fp16 (Is_uint8 = false).
// Decodes the struct's .data byte to fp32 first, then narrows to half.
template<class FromType, class ToType, bool Is_uint8 = false, typename std::enable_if<!Is_uint8 && std::is_same<FromType, Float8_e4m3_t>::value && std::is_same<ToType,half_t>::value, int>::type = 0>
__host__ __device__ half_t UpCast(const Float8_e4m3_t &from_var) {
    const float decoded = __e4m32float(from_var.data);
    return __float2half(decoded);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Primary template of DownCastPair: converts a 2-element vector of FromType
// into a 2-element vector of ToType. Only the explicit specializations are
// usable; instantiating this primary template is a compile-time error.
//
// The assertion is written so that it depends on a template parameter.
// A non-dependent `static_assert(false ...)` inside a template is ill-formed
// before C++23 and is rejected by some compilers even when the template is
// never instantiated.
template<class FromType, class ToType>
inline __host__ __device__ auto DownCastPair(const vec2_Element<FromType>& source) {
    static_assert(sizeof(FromType) == 0, "No Cvt method for DownCastPair!");
    return vec2_Element<ToType>(0);
}
// DownCastPair specialization: two fp32 lanes -> two fp16 lanes.
//   * gfx938 / gfx946 / gfx92a: __builtin_hcu_cvt_pk_f16_f32, result
//     reinterpreted as vec2_Element<half_t>.
//   * other targets: __builtin_amdgcn_cvt_pkrtz, returned as-is.
// NOTE(review): because the return type is `auto`, the two branches may
// deduce different types; the "rtz" in the fallback builtin's name suggests
// round-toward-zero, which may differ from the HCU builtin - confirm.
template<>
inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = __builtin_hcu_cvt_pk_f16_f32(source[0], source[1], false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result);
#else
return __builtin_amdgcn_cvt_pkrtz(source[0], source[1]);
#endif
}
// DownCastPair specialization: two fp32 lanes -> two bf16 lanes.
//   * gfx938 / gfx946 / gfx92a: __builtin_hcu_cvt_pk_bf16_f32, result
//     reinterpreted as vec2_Element<bhalf_t>.
//   * other targets: each lane converted separately with the NaN-free helper
//     inlineasm_float2bfloat16_ushort_nonan (raw 16-bit results stored per lane).
template<>
inline __host__ __device__ auto DownCastPair<float, bhalf_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = __builtin_hcu_cvt_pk_bf16_f32(source[0], source[1], false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result);
#else
vec2_Element<bhalf_t> result;
result[0] = inlineasm_float2bfloat16_ushort_nonan(source[0]);
result[1] = inlineasm_float2bfloat16_ushort_nonan(source[1]);
return result;
#endif
}
// Variant of DownCastPair for the case where src0 and src1 are NOT stored
// contiguously: the two scalars are passed separately and packed into a
// 2-element vector of ToType. Only the explicit specializations are usable;
// instantiating this primary template is a compile-time error.
//
// The assertion is written so that it depends on a template parameter.
// A non-dependent `static_assert(false ...)` inside a template is ill-formed
// before C++23 and is rejected by some compilers even when the template is
// never instantiated.
template<class FromType, class ToType>
inline __host__ __device__ auto DownCastPairNoPack(const FromType src0, const FromType src1) {
    static_assert(sizeof(FromType) == 0, "No Cvt method for DownCastPairNoPack!");
    return vec2_Element<ToType>(0);
}
// DownCastPairNoPack specialization: two separate fp32 scalars -> packed fp16 pair.
//   * gfx938 / gfx946 / gfx92a: __builtin_hcu_cvt_pk_f16_f32, result
//     reinterpreted as vec2_Element<half_t>.
//   * other targets: __builtin_amdgcn_cvt_pkrtz, returned as-is.
// NOTE(review): with an `auto` return type the two branches may deduce
// different types - confirm callers accept both.
template<>
inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float src0, const float src1) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = __builtin_hcu_cvt_pk_f16_f32(src0, src1, false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result);
#else
return __builtin_amdgcn_cvt_pkrtz(src0, src1);
#endif
}
// DownCastPairNoPack specialization: two separate fp32 scalars -> packed bf16 pair.
//   * gfx938 / gfx946 / gfx92a: __builtin_hcu_cvt_pk_bf16_f32, result
//     reinterpreted as vec2_Element<bhalf_t>.
//   * other targets: each scalar converted with the NaN-free helper
//     inlineasm_float2bfloat16_ushort_nonan and stored into its lane.
template<>
inline __host__ __device__ auto DownCastPairNoPack<float, bhalf_t>(const float src0, const float src1) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = __builtin_hcu_cvt_pk_bf16_f32(src0, src1, false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result);
#else
vec2_Element<bhalf_t> result;
result[0] = inlineasm_float2bfloat16_ushort_nonan(src0);
result[1] = inlineasm_float2bfloat16_ushort_nonan(src1);
return result;
#endif
}
// DownCastPairNoPack specialization for float -> float: no conversion, the
// two scalars are simply placed into lanes 0 and 1 of a packed pair, which
// is then reinterpreted as vec2_Element<float>.
template<>
inline __host__ __device__ auto DownCastPairNoPack<float, float>(const float src0, const float src1) {
    __float2 packed;
    packed[0] = src0;
    packed[1] = src1;
    vec2_Element<float>* as_vec2 = (vec2_Element<float>*)(&packed);
    return *as_vec2;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Stand-alone upcast helper used by the split-KV path; kept separate from
// UpCast to avoid regressions there.
// Dispatch on FromType, resolved at compile time:
//   * half_t / __half      -> __half2float
//   * __hip_bfloat16       -> __bfloat162float
//   * unsigned short       -> treated as raw bf16 bits, copied into a bhalf_t,
//                             then __bfloat162float
//   * anything else        -> returned through implicit conversion to float
// accumType is not referenced in the body.
template<typename accumType, class FromType>
__host__ __device__ float splitkv_upcast_to_f32(const FromType &from_var) {
    constexpr bool kFromHalf = std::is_same<FromType, half_t>::value or std::is_same<FromType, __half>::value;
    constexpr bool kFromBf16 = std::is_same<FromType, __hip_bfloat16>::value;
    constexpr bool kFromRawBf16 = std::is_same<FromType, unsigned short>::value;

    if constexpr (kFromHalf) {
        return __half2float(from_var);
    } else if constexpr (kFromBf16) {
        return __bfloat162float(from_var);
    } else if constexpr (kFromRawBf16) {
        bhalf_t holder;
        *(unsigned short*)&holder = from_var;
        return __bfloat162float(holder);
    } else {
        return from_var;
    }
}
// Converts four fp32 values into four 8-bit floats packed into `container`.
// Two packed-pair builtin calls are issued: the first converts source[0..1]
// with the selector flag false, the second converts source[2..3] with the
// flag true; each call receives the current `container` and returns the
// updated one.
//   * output_dtype == fp8_e4m3 -> __builtin_hcu_cvt_pk_fp8_f32
//   * output_dtype == fp8_e5m2 -> __builtin_hcu_cvt_pk_bf8_f32
//   * any other output_dtype   -> compile-time error
// Only implemented for gfx938 / gfx946 / gfx92a; on other targets the body is
// empty and `container` is left unchanged.
//
// The rejection assert depends on output_dtype: a non-dependent
// `static_assert(false ...)` inside a template is ill-formed before C++23 and
// some compilers reject it even in a discarded `if constexpr` branch.
template<typename output_dtype>
__forceinline__ __device__ void __builtin_hcu_cvt_pk4_fp8_f32(const vec4_fp32& source, int32_t &container) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
    if constexpr (std::is_same<output_dtype, fp8_e4m3>::value) {
        container = __builtin_hcu_cvt_pk_fp8_f32(source[0], source[1], container, false/*op_sel:[0,0,0,0]*/);
        container = __builtin_hcu_cvt_pk_fp8_f32(source[2], source[3], container, true/*op_sel:[0,0,0,1]*/);
    } else if constexpr (std::is_same<output_dtype, fp8_e5m2>::value) {
        container = __builtin_hcu_cvt_pk_bf8_f32(source[0], source[1], container, false/*op_sel:[0,0,0,0]*/);
        container = __builtin_hcu_cvt_pk_bf8_f32(source[2], source[3], container, true/*op_sel:[0,0,0,1]*/);
    } else {
        static_assert(sizeof(output_dtype) == 0, "Inputs of invalid dtype fed to __builtin_hcu_cvt_pk4_fp8_f32");
    }
#endif
}
#pragma once
#include <vector>
#include "numeric_types.h"
#include "intrinsic.h"
// ======================================================= MLS ===========================================================
// Keeps only the low 48 bits of x. The argument is parenthesized so that
// expressions with lower-precedence operators (e.g. `a | b`) are masked as a whole.
#define VA_LIMIT_BITS(x) (0xffffffffffff & (x))
// Emits one matrix_load_32x32_b16 ... lds instruction (gfx938 form).
// The second scalar operand is LDSADDR + 0x80000000 (bit 31 set; presumably
// this selects the transposed layout - TODO confirm against the ISA manual).
// R and T are stringified into the instruction text as optional modifiers; a
// separating blank is inserted so that passing both yields "r t" rather than
// the fused token "rt".
// The macro declares a local `soffset`, so each use must sit in its own scope.
#define MATRIX_LOAD_32X32_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
    int soffset = (LDSADDR) + 0x80000000; \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x32_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(soffset), "n"(0) \
                 :);
// gfx946 / gfx92a form: same instruction, LDSADDR is passed through unchanged.
#define MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x32_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
                 :);
// Issues a 32x32 b16 matrix load into LDS using the "trans" macro variants.
//   r, t        : compile-time flags; a non-zero value adds the corresponding
//                 "r" / "t" modifier to the instruction.
//   shared_addr : LDS base pointer; its integer value plus lds_offset forms
//                 the address operand.
//   srsrc       : 4-dword resource descriptor (see the layout note below).
//   lds_offset  : offset added to shared_addr (read only, never written).
//   offset      : not used by the body.
// On targets other than gfx938 / gfx946 / gfx92a the function is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    /*
    matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
    VDATA:DST
    SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
    sgpr[SRSRC+2]: stride
    sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    if constexpr (r && t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,,);
    }
#endif
}
// Emits one matrix_load_32x32_b16 ... lds instruction (gfx938 form).
// The second scalar operand is LDSADDR + 0x00000000, i.e. bit 31 is left
// clear (the non-"trans" counterpart of MATRIX_LOAD_32X32_B16_LDS_TRANS).
// R and T are stringified into the instruction text as optional modifiers; a
// separating blank is inserted so that passing both yields "r t" rather than
// the fused token "rt".
// The macro declares a local `soffset`, so each use must sit in its own scope.
#define MATRIX_LOAD_32X32_B16_LDS(LDSADDR, SRSRC, R, T) \
    int soffset = (LDSADDR) + 0x00000000; \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x32_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(soffset), "n"(0) \
                 :);
// gfx946 / gfx92a form: same instruction, LDSADDR is passed through unchanged.
#define MATRIX_LOAD_32X32_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x32_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
                 :);
// Issues a 32x32 b16 matrix load into LDS using the non-"trans" macro variants.
//   r, t        : compile-time flags; a non-zero value adds the corresponding
//                 "r" / "t" modifier to the instruction.
//   shared_addr : LDS base pointer; its integer value plus lds_offset forms
//                 the address operand.
//   srsrc       : 4-dword resource descriptor (see the layout note below).
//   lds_offset  : offset added to shared_addr (read only, never written).
//   offset      : not used by the body.
// On targets other than gfx938 / gfx946 / gfx92a the function is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    /*
    matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
    VDATA:DST
    SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
    sgpr[SRSRC+2]: stride
    sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    if constexpr (r && t) {
        MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,,);
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc,,);
    }
#endif
}
// ======================================================= MLS32x16 ===========================================================
// Emits one matrix_load_32x16_b16 ... lds instruction (gfx938 form).
// The second scalar operand is LDSADDR + 0x80000000 (bit 31 set; presumably
// this selects the transposed layout - TODO confirm against the ISA manual).
// R and T are stringified into the instruction text as optional modifiers; a
// separating blank is inserted so that passing both yields "r t" rather than
// the fused token "rt".
// The macro declares a local `soffset`, so each use must sit in its own scope.
#define MATRIX_LOAD_32X16_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
    int soffset = (LDSADDR) + 0x80000000; \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x16_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(soffset), "n"(0) \
                 :);
// gfx946 / gfx92a form: same instruction, LDSADDR is passed through unchanged.
#define MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x16_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
                 :);
// Issues a 32x16 b16 matrix load into LDS using the "trans" macro variants.
//   r, t        : compile-time flags; a non-zero value adds the corresponding
//                 "r" / "t" modifier to the instruction.
//   shared_addr : LDS base pointer; its integer value plus lds_offset forms
//                 the address operand.
//   srsrc       : 4-dword resource descriptor (see the layout note below).
//   lds_offset  : offset added to shared_addr (read only, never written).
//   offset      : not used by the body.
// On targets other than gfx938 / gfx946 / gfx92a the function is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    /*
    Operand layout as documented for matrix_load_32x32_b16 (the 32x16 form is used here):
    matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
    VDATA:DST
    SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
    sgpr[SRSRC+2]: stride
    sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,,);
    }
#endif
}
// Emits one matrix_load_32x16_b16 ... lds instruction (gfx938 form).
// The second scalar operand is LDSADDR + 0x00000000, i.e. bit 31 is left
// clear (the non-"trans" counterpart of MATRIX_LOAD_32X16_B16_LDS_TRANS).
// R and T are stringified into the instruction text as optional modifiers; a
// separating blank is inserted so that passing both yields "r t" rather than
// the fused token "rt".
// The macro declares a local `soffset`, so each use must sit in its own scope.
#define MATRIX_LOAD_32X16_B16_LDS(LDSADDR, SRSRC, R, T) \
    int soffset = (LDSADDR) + 0x00000000; \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x16_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(soffset), "n"(0) \
                 :);
// gfx946 / gfx92a form: same instruction, LDSADDR is passed through unchanged.
#define MATRIX_LOAD_32X16_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
    asm volatile("s_nop 0\n\t" \
                 "matrix_load_32x16_b16 %0, %1 moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
                 :);
// Issues a 32x16 b16 matrix load into LDS using the non-"trans" macro variants.
//   r, t        : compile-time flags; a non-zero value adds the corresponding
//                 "r" / "t" modifier to the instruction.
//   shared_addr : LDS base pointer; its integer value plus lds_offset forms
//                 the address operand.
//   srsrc       : 4-dword resource descriptor (see the layout note below).
//   lds_offset  : offset added to shared_addr (read only, never written).
//   offset      : not used by the body.
// On targets other than gfx938 / gfx946 / gfx92a the function is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    /*
    Operand layout as documented for matrix_load_32x32_b16 (the 32x16 form is used here):
    matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
    VDATA:DST
    SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
    sgpr[SRSRC+2]: stride
    sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,,);
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc,,);
    }
#endif
}
// ======================================================= DS ===========================================================
// LDS matrix-read macros for 16-bit elements. Each macro expands to an
// `if constexpr (TRANS)` statement that selects ds_read_matrix_trans_format
// or ds_read_matrix_format, so it must be used where TRANS is a constant
// expression. OFFSET is bound as a scalar ("s") input; REG / REG1 are bound
// as vector ("=v") outputs.
// NOTE(review): the variants that execute `s_mov_b32 m0, ...` write m0
// without listing it as a clobber.

// Reads two tiles: copies OFFSET into m0, then issues two reads at byte
// offsets 0 and 1024 into REG and REG1 (row:0x2 col:0x1 alt:0x0).
#define DS_READ_MATRIX_32X32_B16(OFFSET, REG, REG1, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_mov_b32 m0, %2\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_trans_format %1, m0 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %2\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_format %1, m0 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
}
// Same two reads as DS_READ_MATRIX_32X32_B16, but the address operand is the
// scalar register holding OFFSET directly (m0 is not used).
#define DS_READ_MATRIX_32X32_B16_GFX946(OFFSET, REG, REG1, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_trans_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
}
// Reads a single tile: copies OFFSET into m0 and issues one read at byte
// offset 0 into REG (row:0x2 col:0x1 alt:0x0).
#define DS_READ_MATRIX_32X16_B16(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
}
// Single-tile read with alt:0x1. The TRANS branch uses row:0x1 col:0x2 while
// the non-TRANS branch uses row:0x2 col:0x1.
// NOTE(review): confirm the differing row/col values between the two branches are intended.
#define DS_READ_MATRIX_32X16_B16_ALT2(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x2 row:0x1 col:0x2 alt:0x1\n\t" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x1\n\t" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
}
// Two-tile read with alt:0x1 at byte offsets 0 and 1024. As in the 32X16
// ALT2 macro, the TRANS branch uses row:0x1 col:0x2 and the non-TRANS branch
// uses row:0x2 col:0x1.
#define DS_READ_MATRIX_32X32_B16_ALT2(OFFSET, REG, REG1, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_mov_b32 m0, %2\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x2 row:0x1 col:0x2 alt:0x1\n\t" \
"ds_read_matrix_trans_format %1, m0 offset:1024 element:0x2 row:0x1 col:0x2 alt:0x1\n\t" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %2\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x2 row:0x2 col:0x1 alt:0x1\n\t" \
"ds_read_matrix_format %1, m0 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x1\n\t" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
}
// Clamps `source`: first caps it at max_value, then raises the result to at
// least min_value.
template<int min_value, int max_value>
__forceinline__ __device__ int inline_min_max(int source) {
    /*
    Written as separate min / max calls on purpose (original author's rationale):
      avoid v_med3_i32
      ----> avoid __builtin_amdgcn_readfirstlane
      ----> avoid the 5 nops needed for the mls data hazard
    */
    const int capped = min(max_value, source);
    return max(min_value, capped);
    // Scalar inline-asm alternative, kept for reference:
    // int result;
    // asm volatile("s_max_i32 %0, %1, %2\n\t"
    //              "s_min_i32 %0, %0, %3\n"
    //              : "=s"(result)
    //              : "s"(source), "n"(min_value), "n"(max_value)
    //              :);
    // return result;
}
// ======================================================= def ===========================================================
#define YY_USE_MPERMUTE
// Issues ds_mpermute_dwordx2 (offset:6) with `data` as both source and
// destination. No wait is issued here; the caller is responsible for waiting
// on the LDS counter before consuming `data`
// (see ds_mpermute_kdim_for_mmac_wait for the waiting variant).
//
// `data` is bound as a read-write operand ("+v"). It was previously bound as
// input-only even though the instruction writes it, which lets the compiler
// assume the value is unchanged after the asm statement.
template<typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac(VEC& data) {
    asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n" : "+v"(data));
}
// Issues ds_mpermute_dwordx2 (offset:6) with `data` as both source and
// destination, followed by s_waitcnt lgkmcnt(0), so the result is available
// when the asm statement ends.
//
// `data` is bound as a read-write operand ("+v"). It was previously bound as
// input-only even though the instruction writes it, which lets the compiler
// assume the value is unchanged after the asm statement.
template<typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac_wait(VEC& data) {
    asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n\ts_waitcnt lgkmcnt(0)" : "+v"(data));
}
// ======================================================= mmac ===========================================================
// Generic matrix multiply-accumulate wrapper: returns the result of
// __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3), where v3 is the fp32
// accumulator input. This primary template forwards to the f16 builtin for
// every T; AccumType is not referenced in the body.
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac_4interleave(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
}
// fp16 specialization of mmac_4interleave with an fp32 accumulator.
//   * gfx938 / gfx946: the "_lit_lts" builtin variant, called with trailing
//     arguments (1, 0) (NOTE(review): their meaning is not visible here -
//     confirm against the builtin's documentation).
//   * other targets: the plain f16 builtin.
template<>
inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx938__) || defined(__gfx946__)
return __builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(v1, v2, v3, 1, 0);
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
// bf16 specialization of mmac_4interleave with an fp32 accumulator.
//   * gfx938 / gfx946: the "_lit_lts" builtin variant, called with trailing
//     arguments (1, 0) (NOTE(review): their meaning is not visible here -
//     confirm against the builtin's documentation).
//   * other targets: the plain bf16 builtin.
template<>
inline __device__ vec4_fp32 mmac_4interleave<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx938__) || defined(__gfx946__)
return __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(v1, v2, v3, 1, 0);
#else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#endif
}
#pragma once
// Emits one matrix_load_128x16_b8 ... lds instruction, preceded by s_nop 4.
// The second scalar operand is LDSADDR + 0x80000000 (bit 31 set; presumably
// this selects the transposed layout - TODO confirm against the ISA manual).
// MATRIX_OFFSET is bound with the "n" constraint, so it must fold to a
// compile-time constant at the point of use.
// R and T are stringified into the instruction text as optional modifiers; a
// separating blank is inserted so that passing both yields "r t" rather than
// the fused token "rt".
// The macro declares a local `soffset`, so each use must sit in its own scope.
#define MATRIX_LOAD_128X16_B8_LDS_TRANS(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
    int soffset = (LDSADDR) + 0x80000000; \
    asm volatile("s_nop 4\n\t" \
                 "matrix_load_128x16_b8 %0, %1, moffset:%2 " #R " " #T " lds\n" \
                 :: "s"(SRSRC), "s"(soffset), "n"(MATRIX_OFFSET) \
                 :);
// Issues a 128x16 8-bit matrix load into LDS.
//   r, t          : compile-time flags; a non-zero value adds the
//                   corresponding "r" / "t" modifier to the instruction.
//   shared_addr   : LDS base pointer; its integer value plus lds_offset forms
//                   the address operand.
//   srsrc         : 4-dword resource descriptor.
//   lds_offset    : offset added to shared_addr.
//   matrix_offset : forwarded to the instruction's moffset field
//                   (NOTE(review): it is a runtime parameter bound with an
//                   immediate constraint, so this relies on inlining and
//                   constant propagation).
// On targets other than gfx938 / gfx946 / gfx92a the function is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,, t);
    } else {
        MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,,);
    }
#endif
}
// Reads a 64x16 8-bit matrix tile from LDS into REG with a single ds_read_matrix instruction.
//   OFFSET - LDS address (scalar operand) that is loaded into m0 before the read.
//   REG    - destination vector register(s) ("=v" output).
//   TRANS  - compile-time flag: true selects `ds_read_matrix_trans_format` with
//            m0 = OFFSET + 0x80000000, false selects `ds_read_matrix_format` with m0 = OFFSET.
// Fix: the non-transposed branch previously never wrote m0 (operand %1 was unused), so the read
// used whatever value m0 happened to hold. It now loads m0 from OFFSET, matching the
// DS_READ_MATRIX_32x32_B8 macros.
// The macro expands to an if/else statement; do not use it as the unbraced body of another
// if/else.
// NOTE(review): m0 is written without being listed as a clobber, same as the sibling macros.
#define DS_READ_MATRIX_64x16_B8(OFFSET, REG, TRANS) \
    if constexpr (TRANS) { \
        asm volatile( \
            "s_add_u32 m0, %1, 0x80000000\n\t" \
            "s_nop 0\n\t" \
            "ds_read_matrix_trans_format %0, m0 offset:0 element:0x1 row:0x3 col:0x1 alt:0x0\n" \
            : "=v"(REG) \
            : "s"(OFFSET) \
            :); \
    } else { \
        asm volatile( \
            "s_mov_b32 m0, %1\n\t" \
            "s_nop 0\n\t" \
            "ds_read_matrix_format %0, m0 offset:0 element:0x1 row:0x3 col:0x1 alt:0x0\n" \
            : "=v"(REG) \
            : "s"(OFFSET) \
            :); \
    }
// Emits one `matrix_load_64x32_b8 ... lds` instruction through inline assembly.
//   SRSRC         - resource descriptor, passed as the first scalar ("s") operand.
//   LDSADDR       - LDS address, passed unchanged as the second scalar ("s") operand.
//   MATRIX_OFFSET - must be a compile-time immediate ("n" constraint); used as `moffset:`.
//   R, T          - tokens stringified verbatim into the instruction text as modifiers;
//                   either may be left empty.
// NOTE(review): unlike MATRIX_LOAD_128X16_B8_LDS_TRANS, no 0x80000000 is added to the LDS
// address here; confirm that this difference is intended.
#define MATRIX_LOAD_64x32_B8_LDS_REARRANGE(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
asm volatile("s_nop 4\n\t" \
"matrix_load_64x32_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(MATRIX_OFFSET) \
:);
// Issues a `matrix_load_64x32_b8 ... lds` for the calling wave.
//   r, t          - compile-time flags; a non-zero flag adds the corresponding `r` / `t`
//                   modifier token to the instruction text.
//   shared_addr   - LDS base pointer; its integer value plus `lds_offset` is the LDS address
//                   handed to the load macro.
//   srsrc         - resource descriptor forwarded to the instruction.
//   matrix_offset - forwarded as the `moffset:` immediate (must fold to a constant).
// The body is only compiled for gfx938 / gfx946 / gfx92a; on other targets this is a no-op.
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_64x32_b8_lds_rearrange(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
    int wave_lds_addr = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r != 0) {
        if constexpr (t != 0) {
            MATRIX_LOAD_64x32_B8_LDS_REARRANGE(wave_lds_addr, srsrc, matrix_offset, r, t);
        } else {
            MATRIX_LOAD_64x32_B8_LDS_REARRANGE(wave_lds_addr, srsrc, matrix_offset, r,);
        }
    } else {
        if constexpr (t != 0) {
            MATRIX_LOAD_64x32_B8_LDS_REARRANGE(wave_lds_addr, srsrc, matrix_offset,, t);
        } else {
            MATRIX_LOAD_64x32_B8_LDS_REARRANGE(wave_lds_addr, srsrc, matrix_offset,,);
        }
    }
#endif
}
// Reads a 32x32 8-bit matrix tile from LDS into REG with a single ds_read_matrix instruction
// (alt:0x0 variant).
//   OFFSET - LDS address (scalar operand) that is loaded into m0 before the read.
//   REG    - destination vector register(s) ("=v" output).
//   TRANS  - compile-time flag: true selects `ds_read_matrix_trans_format` with
//            m0 = OFFSET + 0x80000000, false selects `ds_read_matrix_format` with m0 = OFFSET.
// The macro expands to an if/else statement; do not use it as the unbraced body of another
// if/else.
// NOTE(review): m0 is written without being listed as a clobber; confirm this is safe with the
// toolchain in use.
#define DS_READ_MATRIX_32x32_B8(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_add_u32 m0, %1, 0x80000000\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x1 row:0x2 col:0x2 alt:0x0\n" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x1 row:0x2 col:0x2 alt:0x0\n" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
}
// Same as DS_READ_MATRIX_32x32_B8 except that the instruction is issued with `alt:0x1`.
//   OFFSET - LDS address (scalar operand) that is loaded into m0 before the read.
//   REG    - destination vector register(s) ("=v" output).
//   TRANS  - compile-time flag: true selects `ds_read_matrix_trans_format` with
//            m0 = OFFSET + 0x80000000, false selects `ds_read_matrix_format` with m0 = OFFSET.
// The macro expands to an if/else statement; do not use it as the unbraced body of another
// if/else.
// NOTE(review): the layout selected by `alt:0x1` is not described in this file; see the ISA
// documentation.
#define DS_READ_MATRIX_32x32_B8_ALT2(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_add_u32 m0, %1, 0x80000000\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, m0 offset:0 element:0x1 row:0x2 col:0x2 alt:0x1\n" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_mov_b32 m0, %1\n\t" \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, m0 offset:0 element:0x1 row:0x2 col:0x2 alt:0x1\n" \
: "=v"(REG) \
: "s"(OFFSET) \
:); \
}
// 16x16x32 MMAC step for 8-bit (fp8 x fp8) operands.
// Forwards (v1, v2, v3) to the fp8 `_lit_lts` builtin with trailing immediates (1, 0) and
// returns its vec4_fp32 result. There is no per-target selection in this function.
// The AccumType template parameter is not referenced here.
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac_4interleave_b8(const vec8_Element<T> &v1, const vec8_Element<T> &v2, const vec4_fp32 &v3)
{
    const vec4_fp32 result = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(v1, v2, v3, 1, 0);
    return result;
}
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
// #include "cute/algorithm/copy.hpp"
// #include "cutlass/cutlass.h"
// #include "cutlass/layout/layout.h"
#include "numeric_types.h"
// using namespace cute;
// Base traits shared by the forward and backward kernel trait structs.
// Only the element type parameter is used; the integer tile parameters are accepted so that the
// derived traits can name a common base, but they are not referenced in this struct.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kBlockK_, int kWaveM_, int kWaveN_,
         typename elem_type = Float16>
struct Flash_kernel_traits {
    typedef elem_type Element;       // storage element type
    typedef float     ElementAccum;  // accumulator type
    typedef uint32_t  index_t;       // index type
};
// Compile-time configuration of the forward kernels.
// Upstream note: if Share_Q_K_smem is true, that forces Is_Q_in_regs to be true. Neither
// Is_Q_in_regs_ nor Share_Q_K_smem_ is referenced inside this struct.
template<int kHeadDim_, int kHeadDimV_, int kBlockM_, int kBlockN_, int kBlockK_, int kWaveM_, int kWaveN_, int STAGES_,
         bool Is_Q_in_regs_ = false, bool Share_Q_K_smem_ = false,
         typename elem_type = Float16, typename splitkv_accum_dtype = Float16, typename elem_type_k = Float16,
         int kBlockK_int8_ = 64,
         int kHeadDimQKCompute_ = kHeadDim_, int kHeadDimPVCompute_ = kHeadDimV_, int TailTile16_ = 2,
         typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kBlockK_, kWaveM_, kWaveN_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    // ---- element / index types ----
    using Element          = typename Base::Element;
    using ElementAccum     = typename Base::ElementAccum;
    using Element_k        = elem_type_k;
    using index_t          = typename Base::index_t;
    using SplitkvAccumType = splitkv_accum_dtype;

    // ---- tile shape ----
    static constexpr int kBlockM      = kBlockM_;
    static constexpr int kBlockN      = kBlockN_;
    static constexpr int kBlockK      = kBlockK_;
    static constexpr int kBlockK_int8 = kBlockK_int8_;
    static constexpr int kWaveM       = kWaveM_;
    static constexpr int kWaveN       = kWaveN_;
    static constexpr int STAGES       = STAGES_;

    // ---- launch shape: one warp per kWaveM rows of the M tile, 64 threads per warp ----
    static constexpr int kNWarps   = kBlockM_ / kWaveM_;
    static constexpr int kNThreads = kNWarps * 64;

    // ---- head dimensions ----
    static constexpr int kHeadDim          = kHeadDim_;
    static constexpr int kHeadDimV         = kHeadDimV_;
    static constexpr int kHeadDimQKCompute = kHeadDimQKCompute_;
    static constexpr int kHeadDimPVCompute = kHeadDimPVCompute_;
    static constexpr int TailTile16        = TailTile16_;
    // V head dims above 512 are split into kHeadDimV / 128 parts; otherwise a single part.
    static constexpr int SplitD         = (kHeadDimV <= 512) ? 1: kHeadDimV / 128;
    static constexpr int kHeadDimVSplit = kHeadDimV / SplitD;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kHeadDimV % 32 == 0);

    // ---- shared-memory (LDS) budget ----
    static constexpr int kSmemQCount  = 1;
    static constexpr int kSmemKVCount = 2;
    static constexpr int kSmemQSize   = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize  = kSmemKVCount * sizeof(Element);
    // Q and K are counted in 32x34-element tiles per 32x32 block.
    static constexpr size_t q_smem_size = (STAGES * (kBlockM / 32) * (kBlockK / 32) * (32 * 34)) * sizeof(Element);
    static constexpr size_t k_smem_size = (STAGES * (kWaveN / 32) * (kBlockK / 32) * (32 * 34)) * sizeof(Element);
    static constexpr size_t v_smem_size = (STAGES * kBlockK * 32/*WARP_K*/) * sizeof(Element);
#if (TARGET == 928)
    // Target 928: the larger of Q/V plus two K buffers.
    static constexpr int kSmemSize = std::max(q_smem_size, v_smem_size) + k_smem_size * 2;
#else
    // Other targets: the largest of Q, V and two K buffers.
    static constexpr int kSmemSize = std::max(std::max(q_smem_size, v_smem_size), k_smem_size * 2);
#endif
};
// Compile-time configuration of the backward kernels.
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kHeadDimV_, int kBlockM_, int kBlockN_, int kBlockK_, int kWaveM_, int kWaveN_,
         int STAGES_, bool Is_V_in_regs_=false, typename elem_type=Float16,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kBlockK_, kWaveM_, kWaveN_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    // static constexpr bool Has_cp_async = Base::Has_cp_async;
    // using SmemCopyAtom = typename Base::SmemCopyAtom;
    // using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    // Tile shape.
    static constexpr int kWaveM = kWaveM_;
    static constexpr int kWaveN = kWaveN_;
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockK = kBlockK_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kHeadDimV = kHeadDimV_;
    static constexpr int STAGES = STAGES_;
    // Shared-memory budget in bytes; Q and K are counted in 32x34-element tiles per 32x32 block.
    static constexpr int q_smem_size = (STAGES*(kBlockM/32) * (kBlockK/32)*(32*34)) * sizeof(elem_type);
    static constexpr int k_smem_size = (STAGES*(kBlockN/32) * (kBlockK/32)*(32*34)) * sizeof(elem_type);
    static constexpr int v_smem_size = (STAGES*kBlockK * kBlockN) * sizeof(elem_type);
    // std::max is constexpr; an unqualified max() would depend on whichever global overload is
    // visible being usable in a constant expression. Flash_fwd_kernel_traits uses std::max too.
    static constexpr int kSmemSize1colblock = std::max((q_smem_size + k_smem_size), v_smem_size);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "numeric_types.h"
// Epilogue of the gfx938 kvcache kernel: converts the output accumulators in `acc_o` to
// SplitkvAccumType and stores them to global memory.
//   acc_o            - per-lane accumulators, indexed [32x32 tile][mmac tile 0..3].
//   params           - reads o_row_stride, o_batch_stride, o_head_stride, seqlen_q, and either
//                      o_ptr (Split == false) or oaccum_ptr and b (Split == true).
//   bidb, bidh       - batch / head index used for the base offset.
//   m_block          - index of the kBlockM-row block along seqlen_q.
//   split_id         - which split to write to (only used when Split is true).
//   headdim_split_id - which kHeadDimVSplit-wide slice of the V head dimension to write to.
//   warp_id, lane_id - position of the calling thread.
// Rows with seqlen_q_idx >= params.seqlen_q are not written.
// When WARP_NUM == 4 and M_MMAC_COUNT == 1 (OPT_FOR_HDIM128) every warp stores tile 0 of its
// own acc_o to head-dim block (k_loop + warp_id); otherwise only warp 0 stores, walking all
// tiles.
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Interleave2, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_epilogue_store_output_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id) {
    int output_seqlen_stride = params.o_row_stride;
    const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
    // The split offset is widened to 64 bits before multiplying so that
    // split_id * b * o_batch_stride cannot wrap in 32-bit arithmetic. This matches row_offset_o
    // above and the 64-bit split offset used by the varlen epilogue.
    SplitkvAccumType* o_ptr = Split
        ? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ int64_t(split_id) * params.b * params.o_batch_stride
        : reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
    int pv_lane_seq_idx = lane_id & 15;       // row of this lane inside a 16-row mmac tile
    int pv_lane_head_dim_idx = lane_id >> 4;  // column group of this lane
    // Specialized optimization for headdim 128
    constexpr bool OPT_FOR_HDIM128 = (WARP_NUM == 4 and M_MMAC_COUNT == 1);
    if constexpr (not OPT_FOR_HDIM128) {
        if (warp_id > 0) return;
    }
    #pragma unroll
    for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
        #pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
            // In the optimized path each pass covers WARP_NUM head-dim blocks (one per warp).
            int k_loop_stride = OPT_FOR_HDIM128 ? WARP_NUM: 1;
            #pragma unroll
            for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += k_loop_stride) {
                int tile_32x32_id = OPT_FOR_HDIM128 ? 0: k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    // seqlen_q offset
                    int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx + min_tile_m * 16;
                    if (seqlen_q_idx < params.seqlen_q) {
                        if constexpr (Interleave2) {
                            // One contiguous vector store per lane: element i of mmac tile
                            // (min_tile_m) is paired with element i of mmac tile (min_tile_m + 2).
                            union_vec4_f16x2<SplitkvAccumType> v_data;
                            v_data.f16x2[0 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[0], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[0]);
                            v_data.f16x2[1 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[1], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[1]);
                            v_data.f16x2[0 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[2], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[2]);
                            v_data.f16x2[1 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[3], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[3]);
                            int pv_global_addr = seqlen_q_idx * output_seqlen_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + pv_lane_head_dim_idx * 8;
                            *(vec4_fp32*)(o_ptr + pv_global_addr) = v_data.f32;
                        } else {
                            // Non-interleaved layout: the two 16-column halves are stored
                            // separately, 4 elements per lane each.
                            #pragma unroll
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                union_vec2_f16x2<SplitkvAccumType> data;
                                int mmac_id = min_tile_m + min_tile_n * 2;
                                #pragma unroll
                                for (int vec_index = 0; vec_index < 2; ++vec_index) {
                                    data.f16x2[vec_index] = DownCastPair<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][mmac_id].f32x2[vec_index]);
                                }
                                int pv_global_addr = seqlen_q_idx * output_seqlen_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + pv_lane_head_dim_idx * 4 + min_tile_n * 16;
                                *(union_vec2_f16x2<SplitkvAccumType>*)(o_ptr + pv_global_addr) = data;
                            }
                        }
                    }
                }
            }
        }
    }
}
// Variable-length epilogue of the gfx938 kvcache kernel: converts the output accumulators in
// `acc_o` to SplitkvAccumType and stores them to global memory.
//   acc_o          - per-lane accumulators, indexed [32x32 tile][mmac tile 0..3].
//   params         - reads o_row_stride, o_head_stride, ngroups, total_q, and either o_ptr
//                    (Split == false) or oaccum_ptr (Split == true).
//   row_offset_o   - element offset of this block's first output row, supplied by the caller.
//   seqlen_q_limit - rows with a block-local index >= this value are not written.
//   split_id       - which split to write to (only used when Split is true).
//   warp_id, lane_id - position of the calling thread.
// bidb, bidh, m_block and headdim_split_id are not referenced in this body.
// A block-local row index r is decomposed as (r / ngroups, r % ngroups) and addressed as
// query row * ngroups * o_row_stride + group * o_head_stride.
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int64_t row_offset_o,
    int seqlen_q_limit,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id) {
    const int row_stride = params.o_row_stride;
    // Number of elements covered by one split of the accumulation buffer.
    const int64_t elems_per_split = params.ngroups * int64_t(params.total_q) * params.o_row_stride;

    SplitkvAccumType* o_ptr;
    if constexpr (Split) {
        o_ptr = reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ split_id * elems_per_split;
    } else {
        o_ptr = reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
    }
    // The descriptor is built but not used by the stores below.
    auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false/*USE_CACHE_SWIZZLE*/>(o_ptr);

    const int lane_row = lane_id & 15;       // row of this lane inside a 16-row mmac tile
    const int lane_col_group = lane_id >> 4; // column group of this lane (8 elements each)

    // Specialized path for head dim 128: all four warps store, one head-dim block each.
    constexpr bool kEveryWarpStores = (WARP_NUM == 4 and M_MMAC_COUNT == 1);
    if (!kEveryWarpStores && warp_id > 0) {
        return;
    }
    constexpr int kLoopStep = kEveryWarpStores ? WARP_NUM : 1;

    #pragma unroll
    for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
        #pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
            #pragma unroll
            for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += kLoopStep) {
                const int tile_id = kEveryWarpStores ? 0 : k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    // Block-local output row handled by this lane.
                    const int seqlen_q_idx = warp_m_idx * 32 + lane_row + min_tile_m * 16;
                    if (seqlen_q_idx >= seqlen_q_limit) {
                        continue;
                    }
                    // Pair element i of mmac tile (min_tile_m) with element i of mmac tile
                    // (min_tile_m + 2) so that each lane writes one contiguous vector.
                    union_vec4_f16x2<SplitkvAccumType> packed;
                    #pragma unroll
                    for (int i = 0; i < 4; ++i) {
                        packed.f16x2[i] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_id][min_tile_m + 0 * 2].f32[i], acc_o[tile_id][min_tile_m + 1 * 2].f32[i]);
                    }
                    const int true_seqlen_q = seqlen_q_idx / params.ngroups;
                    const int true_group_id = seqlen_q_idx % params.ngroups;
                    int pv_global_addr = true_seqlen_q * params.ngroups * row_stride + true_group_id * params.o_head_stride + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + lane_col_group * 8;
                    *(vec4_fp32*)(o_ptr + pv_global_addr) = packed.f32;
                }
            }
        }
    }
}
\ No newline at end of file
#pragma once
#include "kvcache_pv_gemm_utils_gfx938.h"
// PV GEMM of the gfx938 kvcache path: accumulates mmac(p_reg, V tile) into `pv_reg`.
// V tiles are streamed global -> LDS with matrix-load (MLS) instructions and LDS -> registers with
// DS_READ_MATRIX_32X32_B16_ALT2, walking the tiles from the highest index down to 0.
//   v_addr  - descriptor for V; read as a 64-bit base address, and element [1] is copied into
//             the MLS descriptor.
//   k_addr, k_lds - not referenced in this function body.
//   v_lds   - LDS base used for V staging; every warp addresses its own region via warp_id.
//   p_reg   - P operand registers (only p_reg[0] is read).
//   pv_reg  - output accumulators, updated in place.
//   kvcache_seqlen_stride - written to descriptor word 2.
//   max_seq_kv_offset     - drives the row filter (nm_filter) and the empty-warp fallback.
// The static_asserts restrict this implementation to
// M_WARP_COUNT == PV_N_WARP_COUNT == PV_K_WARP_COUNT == 1.
// With STAGES == 2 every loop iteration issues loads for tiles n_loop and n_loop + 1 while
// consuming tiles n_loop + 2 and n_loop + 3, so the two highest tiles must already have been
// requested before this function is entered
// (NOTE(review): presumably by kvcache_prefetch_v_to_lds_gfx938; confirm at the call site).
// The function starts and ends with a __syncthreads().
template<int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=0) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32");
static_assert (M_WARP_COUNT == 1, "for gfx938, only WARP_M = 32 is supported yet!");
static_assert (PV_N_WARP_COUNT == 1, "for gfx938, only WARP_N = 32 is supported yet!");
static_assert (PV_K_WARP_COUNT == 1, "for gfx938, only WARP_K = 32 is supported yet!");
constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
// Registers for V: each step loads a 32x32 block of halves for the mmac; every thread holds 16 halves, i.e. 8 * 2 (two columns of 8 halves each)
union_vec4_f16x2<Element> v_reg[1 * PV_K_WARP_COUNT * PV_N_WARP_COUNT * 2];
// Set up the resource register for MLS
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = kvcache_seqlen_stride; // stride
// Avoid clashing with the LDS that the multi-wave max reduction needs
__syncthreads();
int stage_id = (STAGES == 2) ? 1: 0;
// Load several batches of data at once
constexpr int N_LOOP_STEP = (STAGES == 2) ? 2: 1;
constexpr int N_LOOP_START = (STAGES == 2) ? K_LOOP_COUNT - N_LOOP_STEP * 2: K_LOOP_COUNT - 1;
constexpr int N_LOOP_END = 0;
for (int n_loop = N_LOOP_START; n_loop >= N_LOOP_END; n_loop -= N_LOOP_STEP) {
#pragma unroll
for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id) {
// Byte offset of the 32x32 block this wave loads in this step
int v_mls_warp_global_offset = (n_loop + prefetch_id) * kBlockN * sizeof(Element);
// LDS offset this wave writes to (mind the offset of v_lds relative to smem)
int v_mls_lds_warp_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32) * sizeof(Element);
// Starting byte offset of the data this wave reads
int v_mls_loop_global_offset; // = warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
// Compute the global address MLS reads from, with boundary handling
if constexpr (true) {
int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset; // detects a warp that would fetch no valid data
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // fetching empty data is not supported on 938; fall back to fetching warp 0's data
v_mls_loop_global_offset = real_mls_warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset); // when falling back, use warp 0's nm_filter value
v_srsrc[3] = max_seq_kv_offset % kBlockN == 0 ? 0: nm_filter << 8;
v_srsrc[3] += 0x20000;
}
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
// Wait for the MLS data to arrive
if constexpr (N_LOOP_STEP == 2) {
buffer_load_lds_dwordx1_wait_nosync<3 * V_LOAD_REQUESTS>();
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// Switch to the stage that is read from in this round
if constexpr (STAGES == 2) {
stage_id ^= 1;
}
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false); // hint: multiple prefetching can be applied here
// Split in two halves: 8 ds_read instructions are issued in total; the first 4 (perm) feed the mmac, then wait for the remaining 4 to return, and so on
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = (STAGES == 2) ? n_loop + 2: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
// ============================================================================================================
// Process the second prefetched block
if constexpr (N_LOOP_STEP == 2) {
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
__builtin_amdgcn_sched_barrier(0);
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// Split in two halves: 8 ds_read instructions are issued in total; the first 4 (perm) feed the mmac, then wait for the remaining 4 to return, and so on
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = (STAGES == 2) ? n_loop + 3: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
}
}
// Drain: with two stages the last two loaded tiles (indices 0 and 1) are still pending in LDS
if constexpr (STAGES == 2) {
// Wait for the MLS data to arrive
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<1 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int n_loop = N_LOOP_END - N_LOOP_STEP;
// Switch stage
stage_id ^= 1;
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// Split in two halves: 8 ds_read instructions are issued in total; the first 4 (perm) feed the mmac, then wait for the remaining 4 to return, and so on
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = n_loop + N_LOOP_STEP;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
// ============================================================================================================
// Process the second prefetched block
if constexpr (N_LOOP_STEP == 2) {
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2/*bytes*/;
__builtin_amdgcn_sched_barrier(0);
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);
// Split in two halves: 8 ds_read instructions are issued in total; the first 4 (perm) feed the mmac, then wait for the remaining 4 to return, and so on
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)" :: "B"(2 - min_tile_k - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
int pv_tile_id = n_loop + N_LOOP_STEP + 1;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
asm volatile("s_setprio 0");
}
}
}
__syncthreads(); // here, K/V use more lds and thus reuse it together, so a sync is needed
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
// Issues the first two V-tile matrix loads (global -> LDS) for the calling wave, so that they
// can be in flight while other work runs. The tiles requested are the last two kBlockN-wide
// blocks along kHeadDim, i.e. indices kHeadDim / kBlockN - 2 and kHeadDim / kBlockN - 1.
//   v_addr  - descriptor for V; read as a 64-bit base address, and element [1] is copied into
//             the load descriptor.
//   v_lds   - LDS base for V staging; the destination slot is selected by warp_id, the
//             compile-time stage_id and the tile's position (0 or 1) in this pair.
//   kvcache_seqlen_stride - written to descriptor word 2.
//   max_seq_kv_offset     - drives the row filter and the empty-warp fallback.
// kBlockM, kBlockK, WARP_M, WARP_N and WARP_NUM are not referenced in this body.
// This function does not wait for the loads to complete.
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void kvcache_prefetch_v_to_lds_gfx938(
    vec4_uint v_addr,
    Element* v_lds,
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_kv_offset=0) {
    constexpr int kLoadsPerTile = (WARP_K * kBlockN) / (32 * 32);
    constexpr int kTilesToPrefetch = 2;
    // Start from the second-to-last block.
    const int first_tile = kHeadDim / kBlockN - kTilesToPrefetch;

    // Resource descriptor for the matrix load.
    vec4_uint desc;
    desc[1] = v_addr[1];
    desc[2] = kvcache_seqlen_stride; // stride

    #pragma unroll
    for (int slot = 0; slot < kTilesToPrefetch; ++slot) {
        // Byte offset of the 32x32 block loaded in this step.
        int tile_byte_offset = (first_tile + slot) * kBlockN * sizeof(Element);
        // LDS byte offset this wave writes to (relative to v_lds).
        int lds_byte_offset = (warp_id * STAGES * 2 + stage_id * 2 + slot) * (kLoadsPerTile * 32 * 32) * sizeof(Element);

        // A warp whose rows all lie outside the valid range would fetch nothing, which is not
        // supported on 938; such a warp falls back to warp 0's rows and filter value.
        const bool warp_has_no_rows = (warp_id * WARP_K + 32 - max_seq_kv_offset) >= 32;
        const int source_warp = warp_has_no_rows ? 0 : warp_id;
        int warp_byte_offset = source_warp * WARP_K * kvcache_seqlen_stride * sizeof(Element);
        const int nm_filter = inline_min_max<0, 32>(source_warp * WARP_K + 32 - max_seq_kv_offset);

        desc[3] = (max_seq_kv_offset % kBlockN == 0) ? 0 : nm_filter << 8;
        desc[3] += 0x20000;
        // Overwrites the low 64 bits of the descriptor with the final source address.
        *(uint64_t*)&desc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_byte_offset + tile_byte_offset);

        __builtin_amdgcn_sched_barrier(0);
        inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, desc, lds_byte_offset, 0);
        __builtin_amdgcn_sched_barrier(0);
    }
}
\ No newline at end of file
#pragma once
#include "kvcache_qk_gemm_utils_gfx938.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset=0) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
constexpr int K_LOAD_REQUESTS = (WARP_N / 32) * (kBlockK / 32);
// 分配 k 计算 mmac 需要的寄存器资源
// 一次加载 32x32 个 half, 每个线程持有 16 个 half
union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];
// 初始化 s
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
// 准备 MLS resource 寄存器
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = kcache_seqlen_stride;
int stage_id = 0;
constexpr int K_LOOP_START = (STAGES == 2) ? 2: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop += 2) {
#pragma unroll
for (int prefetch_id = 0; prefetch_id < 2; ++prefetch_id) {
// 计算当前 wave 写到 lds 的起始地址
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * K_LOAD_REQUESTS * (32 * 32);
// 计算当前 wave 沿着 kHeadDim 方向循环读取的起始地址, 读到第几个 32x32 块了
int k_mls_loop_global_offset = (k_loop + prefetch_id) * kBlockK * sizeof(Element);
// 计算当前 wave 从 global 读取数据的起始地址
int k_mls_warp_global_offset; // = warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
if constexpr (true) {
int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset; // 判断是否有 warp 取空数据
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // 如果取空数据, 938 不支持, 退化到取 warp 0 的数据
k_mls_warp_global_offset = real_mls_warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset); // 如果取空数据, 使用 warp 0 的 nm_filter 值
k_srsrc[3] = nm_filter << 8;
}
// 根据偏移计算 global load 的字节偏移数
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
}
// 等待 MLS 数据回来
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<3 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) stage_id ^= 1;
// 加载上一次 MLS 写到 lds 的数据到寄存器
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 2: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
// =============================================================================================================
{
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
// 加载上一次 MLS 写到 lds 的数据到寄存器
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
if constexpr (STAGES == 2) {
// 等待 MLS 数据回来
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<1>();
__builtin_amdgcn_sched_barrier(0);
// 切换到上一次 lds 被写入的轮次
stage_id ^= 1;
// 从 lds 加载最后一部分数据
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 2;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
// ==========================================================================
{
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0>();
__builtin_amdgcn_sched_barrier(0);
// 从 lds 加载最后一部分数据
int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element)/*half -> bytes*/;
DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(%0)\n" :: "B"(2 - min_tile_n - 1));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = kHeadDim / kBlockK - 1;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
kvcache_prefetch_v_to_lds_gfx938<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(v_addr, v_lds, warp_id, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "kvcache_pv_gemm_utils_gfx938.h"
// Loads the Q tile into LDS (q_lds) and from there into the per-lane q_reg fragments.
//
// Two code paths are selected at compile time:
//   * kHeadDim == 128 and WARP_NUM == 4: each wave issues a single
//     inline_matrix_load_32x32_b16_lds_trans into q_lds and then fills q_reg with four
//     DS_READ_MATRIX_* reads.
//   * otherwise: rows are copied into q_lds with builtin_buffer_load_dword_lds, one kBlockK-wide
//     slice of the head dimension per k_loop step (two LDS stages are alternated when STAGES > 1),
//     and q_reg is filled with 64-bit loads from q_lds.
//
// Parameters:
//   q_addr              - 4-word descriptor of Q in global memory (presumably a buffer resource
//                         descriptor; words 0 and 1 are treated as the base address here).
//   q_lds               - LDS staging area for Q.
//   q_reg               - output register fragments, indexed by (head-dim step, tile, sub-tile).
//   warp_id             - index of the calling wave inside the workgroup.
//   query_seqlen_stride - row stride of Q (units are defined by the load helpers - TODO confirm).
//   max_seq_q_offset    - number of valid Q rows; used for the descriptor row filter (fast path)
//                         and to clamp the row index (generic path).
//
// NOTE(review): with the default max_seq_q_offset == 0 the generic path evaluates
// min(row, -1) and therefore addresses row -1; callers presumably always pass a positive value
// on that path - confirm.
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_gfx938(
vec4_uint q_addr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset=0) {
if constexpr (kHeadDim == 128 and WARP_NUM == 4) {
// Prepare the MLS registers (the descriptor passed to the matrix-load helper)
vec4_uint q_srsrc;
q_srsrc[1] = q_addr[1];
q_srsrc[2] = query_seqlen_stride;
// Index of the 32x32 block along the kHeadDim direction
int q_loop = 0;
// Starting LDS offset (in elements) that the current wave writes to
int k_lds_stage_offset = warp_id * (WARP_M / 32) * (kBlockK / 32) * (32 * 32);
// Starting global offset (in elements) that the current wave reads from
int k_mls_warp_global_offset = warp_id * kBlockK;
// Starting offset of the current wave's loop along kHeadDim, i.e. which 32x32 block has been reached
int k_mls_loop_global_offset = q_loop * kBlockK;
// Turn the element offsets into the byte offset of the global load (2 bytes per element)
q_srsrc[0] = q_addr[0] + (k_mls_loop_global_offset + k_mls_warp_global_offset ) * 2;
if constexpr (true) {
// Row filter: 32 - max_seq_q_offset clamped by inline_min_max<0, 32>, stored shifted left by 8 in
// descriptor word 3; it is left at 0 when max_seq_q_offset is a multiple of 32.
int nm_filter = inline_min_max<0, 32>(32 - max_seq_q_offset);
q_srsrc[3] = max_seq_q_offset % 32 == 0 ? 0: nm_filter << 8; // set only once
}
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
flash::wait_lds_data_arrived<true>(0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(q_lds, q_srsrc, lds_offset_bytes, 0);
flash::wait_buffer_data_arrived<true>(0);
// Start reading the data
__builtin_amdgcn_sched_barrier(0);
// Note: when M_MMAC_COUNT == 1 only one read per tile is needed
// NOTE(review): the read offsets below are literal byte offsets (tile index * 32*32*2) that do not
// involve q_lds or warp_id; this presumably relies on q_lds starting at LDS offset 0 and on the
// wait helpers above synchronizing all waves - confirm against the macro/helper definitions.
if constexpr (M_MMAC_COUNT == 1) {
DS_READ_MATRIX_32X16_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, true);
} else {
DS_READ_MATRIX_32X32_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, q_reg[0 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, q_reg[1 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, q_reg[2 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, q_reg[3 * 2 + 1].f16, true);
}
flash::wait_lds_data_arrived<true>(0);
}
else {
// Number of load requests each wave issues per head-dim step
constexpr int Q_LOAD_REQUESTS = (kBlockM * kBlockK >> 1/*16x32 tile*/) * M_MMAC_COUNT / (4 * 32 * WARP_NUM);
constexpr int SEQUENCE_READ = M_MMAC_COUNT;
// Rows covered by one load request of a wave (lane_id >> 4 selects one of 4 rows)
constexpr int READ_ONCE_LINES = 4;
auto BUFFER_LOAD_FUNC = &builtin_buffer_load_dword_lds<Element, float, 1>; // buffer_load_dwordx4 can also be applied if necessary
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = lane_id >> 4;
int q_lane_head_dim_idx = lane_id & 15;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
// Per-lane read offset into q_lds, in 32-bit (vec2_Element) units
int q_ds_read_offset = laneid_and_15 * 16 + laneid_shfl_4 * 2;
int stage_id = 0;
// Prologue (STAGES > 1 only): issue the loads for head-dim step 0 into LDS stage 0.
if constexpr (STAGES > 1) {
int k_loop = 0;
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
// Which group of READ_ONCE_LINES rows this request covers (the mask assumes
// kBlockM / READ_ONCE_LINES is a power of two)
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32);
// Element offsets are halved before being handed to the load helper
// (presumably element -> 32-bit word units - TODO confirm against the helper)
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
// Clamp the row index to max_seq_q_offset - 1, then convert it to a per-lane offset
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
// Main loop: issue the loads for step k_loop, then copy a completed stage from LDS into q_reg.
// With STAGES > 1 the stage copied out is the one loaded for step k_loop - 1.
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32);
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
// Wait for this wave's outstanding memory operations, then synchronize the workgroup
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (STAGES > 1) stage_id ^= 1;
// Stage offset recomputed in 32-bit units (32 * 16 words == 32 * 32 elements)
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for (int j = 0; j < 2; ++j) {
// NOTE(review): lds_offset does not depend on m_idx, so for WARP_M > 32 every m_idx reads the
// same LDS words; presumably WARP_M == 32 on this path - confirm.
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset + j * 8;
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
q_reg[k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].u64[j] = *(__float2*)(q_lds_v2fp16 + lds_offset);
}
}
}
}
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
// Epilogue (STAGES > 1 only): copy the stage loaded last into the final q_reg slots.
// NOTE(review): when kHeadDim / kBlockK == 1 the main loop never runs, so only the s_waitcnt below
// (no workgroup barrier) separates the prologue loads from these reads - confirm that
// configuration is never instantiated with STAGES > 1.
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for (int j = 0; j < 2; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset + j * 8;
q_reg[((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].u64[j] = *(__float2*)(q_lds_v2fp16 + lds_offset);
}
}
}
}
}
// Final fence: wait for outstanding operations and synchronize the workgroup before returning
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
}
}
// Issues the first two K matrix loads (prefetch slots 0 and 1 of ping-pong stage 0) for the
// calling wave via inline_matrix_load_32x32_b16_lds_trans.
//
// Parameters:
//   k_addr                - 4-word descriptor of K in global memory; its low 64 bits are used as
//                           the base address.
//   k_lds                 - LDS destination buffer for K.
//   warp_id               - index of the calling wave; selects both the LDS slot and the K rows.
//   kvcache_seqlen_stride - row stride of K in elements (multiplied by sizeof(Element) here).
//   max_seq_k_offset      - number of valid K rows; used to build the row filter and to detect a
//                           wave whose rows are all out of range.
template<int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void kvcache_prefetch_k_to_lds_gfx938(
    vec4_uint k_addr,
    Element* k_lds,
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_k_offset=0) {
    // Descriptor handed to the matrix-load helper; word 2 carries the row stride.
    vec4_uint desc;
    desc[1] = k_addr[1];
    desc[2] = kvcache_seqlen_stride;

    // This prefetch always targets the first ping-pong stage and the first block along kHeadDim.
    constexpr int kStage = 0;
    constexpr int kHeadDimStep = 0;

    #pragma unroll
    for (int slot = 0; slot < 2; ++slot) {
        // LDS destination of this wave/slot, in elements.
        const int lds_elem_offset = (warp_id * STAGES * 2 + kStage * 2 + slot) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
        // Byte offset along the head dimension for this slot.
        const int head_dim_byte_offset = (kHeadDimStep + slot) * kBlockK * sizeof(Element);

        // A wave whose first row is already past max_seq_k_offset would fetch nothing; per the
        // original author's note that is unsupported on 938, so it re-reads wave 0's rows instead.
        const int rows_past_end = warp_id * WARP_N + 32 - max_seq_k_offset;
        const int src_warp = (rows_past_end >= 32) ? 0 : warp_id;
        const int warp_byte_offset = src_warp * WARP_N * kvcache_seqlen_stride * sizeof(Element);

        // Row filter for the (possibly substituted) source wave, stored shifted left by 8 in word 3.
        const int row_filter = inline_min_max<0, 32>(src_warp * WARP_N + 32 - max_seq_k_offset);
        desc[3] = row_filter << 8;

        // Words 0..1 of the descriptor: 64-bit base address plus both byte offsets.
        *(uint64_t*)&desc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + head_dim_byte_offset + warp_byte_offset);

        int lds_offset_bytes = lds_elem_offset * 2/*half -> bytes*/;
        inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, desc, lds_offset_bytes, 0);
        __builtin_amdgcn_sched_barrier(0);
    }
}
#pragma once
#include "fwd/utils.h"
using namespace flash;
// Writes -INFINITY into every element of `tensor` whose column index is >= max_seqlen_k.
//
// Column index of an element: col_idx_offset_ + (lane_id >> 4) * 4 + n_tile * 32 + n_half * 16 + elem,
// with lane_id = threadIdx.x & 63. The row position is irrelevant: all M tiles / M sub-tiles of a
// masked column are overwritten.
//
// Parameters:
//   tensor          - fragment array indexed [m_tile + n_tile * M_WARP_COUNT][n_half * 2 + m_half].f32[elem].
//   max_seqlen_k    - first column index that must be masked.
//   col_idx_offset_ - column index of the first column covered by this fragment.
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k,
                                                 const int col_idx_offset_ = 0) {
    // Lanes are split into 4 groups of 16; each group owns 4 consecutive columns.
    const int lane_col_start = col_idx_offset_ + ((threadIdx.x & 63) >> 4) * 4;
    #pragma unroll
    for (int n_tile = 0; n_tile < N_WARP_COUNT; ++n_tile) {
        #pragma unroll
        for (int n_half = 0; n_half < 2; ++n_half) {
            const int first_col = lane_col_start + n_tile * 32 + n_half * 16;
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem) {
                // In-range columns are left untouched.
                if (first_col + elem < max_seqlen_k) {
                    continue;
                }
                #pragma unroll
                for (int m_tile = 0; m_tile < M_WARP_COUNT; ++m_tile) {
                    #pragma unroll
                    for (int m_half = 0; m_half < M_MMAC_COUNT; ++m_half) {
                        tensor[m_tile + n_tile * M_WARP_COUNT][n_half * 2 + m_half].f32[elem] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Applies a causal mask to `tensor`: an element at (row, col) is set to -INFINITY when
//   col > min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups).
//
// Index mapping (lane_id = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane_id & 15)      + mi * 32 + min_tile_m * 16
//   col = col_idx_offset_ + (lane_id >> 4) * 4  + ni * 32 + min_tile_n * 16 + vec_idx
//
// Parameters:
//   tensor          - fragment array indexed [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx].
//   col_idx_offset_ - column index of the first column covered by this fragment.
//   max_seqlen_k    - key sequence length used in the limit formula above.
//   row_idx_offset_ - row index of the first row covered by this fragment.
//   max_seqlen_q    - query sequence length used in the limit formula above.
//   ngroups         - divisor applied to row indices and max_seqlen_q (must be non-zero).
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                        const int max_seqlen_k, const int row_idx_offset_,
                                                        const int max_seqlen_q, const int ngroups) {
    const int lane_id = threadIdx.x & 63;
    // `+` binds tighter than `&`: without the parentheses the expression is
    // (row_idx_offset_ + lane_id) & 15, which throws the tile's row offset away.
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            // Last column this row may keep; never beyond max_seqlen_k.
            const int col_idx_limit_right = std::min(max_seqlen_k, (row_idx / ngroups)/*only for layout 1: bshd*/ + max_seqlen_k - (max_seqlen_q / ngroups));
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx;
                        // Select form (instead of a branch) keeps the store unconditional.
                        tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
                    }
                }
            }
        }
    }
}
// Causal mask variant for the `mtp` case: an element at (row, col) is set to -INFINITY when
//   col > min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp),
// where row_in_mtp = row % mtp                    if layout == 0
//                  = row / (max_seqlen_q / mtp)   otherwise.
//
// Index mapping (lane_id = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane_id & 15)      + mi * 32 + min_tile_m * 16
//   col = col_idx_offset_ + (lane_id >> 4) * 4  + ni * 32 + min_tile_n * 16 + vec_idx
//
// Parameters:
//   tensor          - fragment array indexed [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx].
//   col_idx_offset_ - column index of the first column covered by this fragment.
//   max_seqlen_k    - key sequence length used in the limit formula above.
//   row_idx_offset_ - row index of the first row covered by this fragment.
//   max_seqlen_q    - total number of query rows; max_seqlen_q / mtp is the regroup count.
//   mtp             - divisor/modulus in the formulas above (must be non-zero, and
//                     max_seqlen_q / mtp must be non-zero when layout != 0).
//   layout          - 0 selects the modulo mapping, any other value the division mapping.
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                            const int max_seqlen_k, const int row_idx_offset_,
                                                            const int max_seqlen_q, const int mtp, const int layout) {
    const int MTP_REGROUP_COUNT = max_seqlen_q / mtp;
    const int lane_id = threadIdx.x & 63;
    // `+` binds tighter than `&`: without the parentheses the expression is
    // (row_idx_offset_ + lane_id) & 15, which throws the tile's row offset away.
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int row_in_mtp = layout == 0 ? (row_idx % mtp): (row_idx / MTP_REGROUP_COUNT);
            // Last column this row may keep; never beyond max_seqlen_k.
            const int col_idx_limit_right = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx;
                        // Select form (instead of a branch) keeps the store unconditional.
                        tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
                    }
                }
            }
        }
    }
}
#include "numeric_types.h"
// Sums the acc_o accumulators of 4 waves through LDS and leaves the total in every wave's acc_o.
//
// Steps:
//   1. Each wave stores its acc_o fragments into its own LDS slot
//      (slot stride: EVEN_REUSE_KV_TIMES * __kHeadDim floats, selected by WARP_ID).
//   2. After a barrier, wave w adds element w of every 4-float group from slots 1..3 onto slot 0.
//      Three variants exist: hand-scheduled packed-f32 inline asm, ds_read2 builtins, plain loops.
//   3. After another barrier, every wave reloads the summed values from slot 0 into acc_o.
//
// Only lanes with (lane_id & 15) < HALF_REUSE_KV_TIMES take part; such a lane handles LDS rows
// 2 * (lane_id & 15) + min_tile_m.
//
// Parameters:
//   acc_o     - in/out accumulator fragments of the calling wave.
//   acc_o_lds - LDS scratch buffer of floats; the offsets used here reach up to
//               4 * EVEN_REUSE_KV_TIMES * __kHeadDim floats (4 slots are hard-coded).
//   seqlen_q  - used only when REUSE_KV_TIMES <= 0, to derive the row count at run time.
//   WARP_ID   - index of the calling wave (0..3 presumably - the code sums exactly 4 slots).
//   lane_id   - lane index inside the wave.
//
// NOTE(review): the barriers below sit inside a branch that depends on lane_id. This relies on
// every wave reaching them with at least one active lane (lane 0 whenever
// HALF_REUSE_KV_TIMES >= 1) - confirm this is acceptable on all targeted architectures.
template<int REUSE_KV_TIMES, int kHeadDim, int kBlockK, int WARP_M, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void int8_kvcache_acco_reduce(
vec4_Accum<ElementAccum> acc_o[(kHeadDim/kBlockK) * ((WARP_M/32)*(kBlockK/32))][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int WARP_ID,
int lane_id) {
// when REUSE_KV not in templated, compute max reuse times
// (the count is rounded up to an even number in both cases)
int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2: ((seqlen_q + 1) / 2) * 2;
int HALF_REUSE_KV_TIMES = EVEN_REUSE_KV_TIMES >> 1;
int q_seq_idx = (lane_id & 15);
// Row pitch in LDS (floats); padded by 4 unless REUSE_KV_TIMES >= 16
constexpr int __kHeadDim = (REUSE_KV_TIMES >= 16) ? kHeadDim: kHeadDim + 4/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
if (q_seq_idx < HALF_REUSE_KV_TIMES) {
// ####################################################################################################################################################
// Each of the 4 waves writes the acc_o results it is responsible for into LDS
for(int h_idx=0; h_idx<(kHeadDim/kBlockK); h_idx++) {
for(int k_idx=0; k_idx<(kBlockK/32); k_idx++) {
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int lds_offset = WARP_ID*EVEN_REUSE_KV_TIMES*__kHeadDim + q_seq_idx*2*__kHeadDim + min_tile_m*__kHeadDim + h_idx*kBlockK + k_idx*32 + min_tile_k*16 + (lane_id>>4)*4;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[h_idx* ((WARP_M/32)*(kBlockK/32)) + k_idx*(WARP_M/32)][min_tile_k*2 + min_tile_m].f32;
}
}
}
}
__syncthreads();
// ####################################################################################################################################################
// All 4 waves take part in summing acc_o inside LDS
// Determine whether the current architecture supports pk_f32 instructions
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
constexpr bool SUPPORT_PK_F32 = true;
#else
constexpr bool SUPPORT_PK_F32 = false;
#endif
// On gfx936, when EVEN_REUSE_KV_TIMES is known at compile time, ds_read2_b32 can be used freely; in the GQA case the pk instructions can be exploited even better, scheduled directly through inline assembly
// NOTE(review): the offsets below hard-code 4 head-dim chunks (0..3 * kBlockK), which matches
// kHeadDim == 128 only when kBlockK == 32; the static_asserts guarding this are commented out.
if constexpr (SUPPORT_PK_F32 and REUSE_KV_TIMES > 0 and M_MMAC_COUNT > 1 and kHeadDim == 128) {
// static_assert (kBlockK == 32 && "only kBlockK=32 is supported!");
// static_assert (kHeadDim == 128 && "only kHeadDim=128 is supported!");
// acc_tmp_wave0[s] receives slot 0 for step s and accumulates the sum; acc_tmp_wave1..3 are
// two-entry ping-pong buffers for slots 1..3.
union_vec2_fp32 acc_tmp_wave0[4*2];
union_vec2_fp32 acc_tmp_wave1[2], acc_tmp_wave2[2], acc_tmp_wave3[2];
// Prefetch the data of the first step
int loop_id = 0;
int lds_offset[4*2];
lds_offset[0] = 0*__kHeadDim + q_seq_idx*2*__kHeadDim + 0*kBlockK + (lane_id>>4)*4 + WARP_ID;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id], acc_tmp_wave0[loop_id].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[0].u64, 0, 16);
__builtin_amdgcn_sched_barrier(0);
// Compute the remaining lds_offset values up front (index = head-dim chunk * 2 + row parity)
lds_offset[1] = 1*__kHeadDim + q_seq_idx*2*__kHeadDim + 0*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[2] = 0*__kHeadDim + q_seq_idx*2*__kHeadDim + 1*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[3] = 1*__kHeadDim + q_seq_idx*2*__kHeadDim + 1*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[4] = 0*__kHeadDim + q_seq_idx*2*__kHeadDim + 2*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[5] = 1*__kHeadDim + q_seq_idx*2*__kHeadDim + 2*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[6] = 0*__kHeadDim + q_seq_idx*2*__kHeadDim + 3*kBlockK + (lane_id>>4)*4 + WARP_ID;
lds_offset[7] = 1*__kHeadDim + q_seq_idx*2*__kHeadDim + 3*kBlockK + (lane_id>>4)*4 + WARP_ID;
// Steps 0..6 below share one pattern: issue the 4 reads of step loop_id + 1, then add slots 1..3
// of step loop_id. With 8 reads in flight, lgkmcnt(6)/(5)/(4) wait for the 2nd/3rd/4th oldest
// read respectively (assuming LDS reads complete in issue order). The ping-pong index of
// acc_tmp_wave1..3 alternates between steps.
// asm volatile("s_nop 8\n");
{
int loop_id = 0/*h_idx*2 + min_tile_m*/;
// Prefetch the data of the next step
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[1].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 1/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[0].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 2/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[1].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 3/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[0].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 4/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[1].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 5/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[0].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[0].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
int loop_id = 6/*h_idx*2 + min_tile_m*/;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1], acc_tmp_wave0[loop_id + 1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2[1].u64, 0, 16);
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset[loop_id + 1] + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3[1].u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// First write the finished part of the results (steps 0..6) back to slot 0 in LDS
for (int loop_id = 0; loop_id < 7; ++loop_id) {
acc_o_lds[lds_offset[loop_id]] = acc_tmp_wave0[loop_id].f32[0];
acc_o_lds[lds_offset[loop_id] + 16] = acc_tmp_wave0[loop_id].f32[1];
}
// Then wait for the last batch of data to arrive, compute and write the final results
{
int loop_id = 7/*h_idx*2 + min_tile_m*/;
asm volatile("s_waitcnt lgkmcnt(2)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(1)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
__builtin_amdgcn_sched_barrier(0);
acc_o_lds[lds_offset[loop_id]] = acc_tmp_wave0[loop_id].f32[0];
acc_o_lds[lds_offset[loop_id] + 16] = acc_tmp_wave0[loop_id].f32[1];
// inline_ds_write2_b32_no_wait_bytes(acc_o_lds, lds_offset[loop_id], acc_tmp_wave0[loop_id].u64, 0, 16);
}
// Used in place of __syncthreads(): drain LDS operations, then barrier
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_waitcnt lgkmcnt(0)\n\t"
"s_barrier");
__builtin_amdgcn_sched_barrier(0);
}
// When EVEN_REUSE_KV_TIMES is known at compile time, ds_read2_b32 can be used freely; this works on both gfx928 and gfx936
else if constexpr (REUSE_KV_TIMES > 0) {
for(int h_idx=0; h_idx<(kHeadDim/kBlockK); h_idx++) {
for(int k_idx=0; k_idx<(kBlockK/32); k_idx++) {
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
union_vec2_fp32 acc_tmp;
// lds_offset0 / lds_offset1 address the two 16-column halves (min_tile_k = 0 / 1) of slot 0
int lds_offset0 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 0*16 + (lane_id>>4)*4 + WARP_ID;
int lds_offset1 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 1*16 + (lane_id>>4)*4 + WARP_ID;
// One ds_read2 fetches both halves (element offsets 0 and 16) at once
acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
// acc_tmp.f32[0] = acc_o_lds[lds_offset0];
// acc_tmp.f32[1] = acc_o_lds[lds_offset1];
union_vec2_fp32 acc_tmp_wave1;
acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave1.f32[0] = acc_o_lds[lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave1.f32[1] = acc_o_lds[lds_offset1 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2;
acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave2.f32[0] = acc_o_lds[lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave2.f32[1] = acc_o_lds[lds_offset1 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3;
acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave3.f32[0] = acc_o_lds[lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave3.f32[1] = acc_o_lds[lds_offset1 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
acc_tmp.f32[1] += acc_tmp_wave3.f32[1];
// ds_write2_b32
acc_o_lds[lds_offset0] = acc_tmp.f32[0];
acc_o_lds[lds_offset1] = acc_tmp.f32[1];
}
}
}
__syncthreads();
} else {
// REUSE_KV_TIMES is not known at compile time, so EVEN_REUSE_KV_TIMES is not either; ds_read2_b32 cannot be called directly, so the optimization is left to the compiler
for(int h_idx=0; h_idx<(kHeadDim/kBlockK); h_idx++) {
for(int k_idx=0; k_idx<(kBlockK/32); k_idx++) {
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int lds_offset = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + min_tile_k*16 + (lane_id>>4)*4 + WARP_ID;
float acc_tmp_wave0 = acc_o_lds[lds_offset];
// Add the matching element of slots 1..3 (4 slots are hard-coded)
for(int loop=1; loop<4; loop++) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop*EVEN_REUSE_KV_TIMES*__kHeadDim];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
}
}
}
}
__syncthreads();
}
// ####################################################################################################################################################
// Every wave fetches the final summed result from LDS (slot 0)
for(int h_idx=0; h_idx<(kHeadDim/kBlockK); h_idx++) {
for(int k_idx=0; k_idx<(kBlockK/32); k_idx++) {
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int lds_offset = q_seq_idx*2*__kHeadDim + min_tile_m*__kHeadDim + h_idx*kBlockK + k_idx*32 + min_tile_k*16 + (lane_id>>4)*4;
acc_o[h_idx* ((WARP_M/32)*(kBlockK/32)) + k_idx*(WARP_M/32)][min_tile_k*2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
}
}
}
}
}
}
\ No newline at end of file
#include "int8_kvcache_pv_gemm_prefetch_k_3stage.h"
// Accumulates P * V into pv_reg for the calling wave, streaming V tiles from global memory
// through this wave's slice of the v_lds staging buffer into registers.
//
// For each n_loop block of kBlockN columns of the V head dimension the function:
//   1. issues V_LOAD_REQUESTS loads through BUFFER_LOAD_FUNC (gV -> v_lds),
//   2. waits for the loads, then reads the tile back with ds_read2_b32 into v_reg,
//   3. repacks p_reg / v_reg into vec4 operands and calls flash::mmac in two halves
//      (min_tile_k = 0 and min_tile_k = 1), accumulating into pv_reg.
// With STAGES == 2 the loop starts at n_loop = 1: every iteration loads block n_loop into one
// stage while consuming the other stage for block n_loop - 1, and a trailing section after the
// loop consumes the last block. Stage 0 is read before this function has written it, so it is
// presumably filled by an earlier prefetch in the caller -- TODO confirm.
//
// Parameters:
//   gV                 - handle passed unchanged to BUFFER_LOAD_FUNC as the load source
//                        (presumably a buffer resource descriptor -- TODO confirm).
//   v_lds              - base of the V staging buffer; offset below by
//                        WARP_ID * STAGES * WARP_K * kBlockN elements to get this wave's slice.
//   p_reg              - left operand registers; only p_reg[0][...] is read here.
//   pv_reg             - accumulators, read and overwritten with the flash::mmac result.
//   WARP_ID            - wave index; selects the staging slice and the global row-block offset.
//   max_seq_kv_offset  - the row index used for the global load is clamped to
//                        max_seq_kv_offset - 1.
//                        NOTE(review): with the default of -1 the clamp value is -2; callers
//                        presumably always pass a real bound -- confirm.
//
// NOTE(review): gK, k_lds, scales_v and vcache_seqlen_stride, as well as the template parameters
// PREFETCH_K, kHeadDim and WARP_NUM, are not referenced in this body (the row stride used for
// the global offsets is kHeadDimV). No dequantization is applied to the values read into v_reg.
//
// The function ends with __syncthreads(), so every thread of the block must reach this call.
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename Element_k, typename ElementAccum>
__forceinline__ __device__ void int8_kvcache_pv_gemm_prefetch_k(
vec4_uint gV,
vec4_uint gK,
Element_k* v_lds,
Element_k* k_lds,
float scales_v[2][4],
union_vec2_f16x2<Element> p_reg[(WARP_M/32)*(WARP_K/32)][4],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV/kBlockN)*(WARP_M/32)*(kBlockN/32)][4],
int WARP_ID,
int vcache_seqlen_stride,
int max_seq_kv_offset = -1) {
static_assert(kBlockK>=32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert(kBlockM>=WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert(kBlockN==WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
// Register copy of the V tile, one group of 32x32 sub-tiles per stage.
union_vec2_f16x2<Element> v_reg[STAGES*((32*WARP_N)/(32*32))][4];
// Precompute some common sub-expressions.
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 threads read one row
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 threads read one row
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 threads read one row
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1; when reading LDS, an 8x32 tile is read in thread order [0, 16, 0, 16, 32, 48, 32, 48], each group of 32 threads reads one 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // 8x32 tile, each thread of a wave reads 4 halfs, i.e. 2 dwords, with ds_read2_b32; with the read pattern above the second dword is offset by 32 dwords
#if defined(USE_BUFFER_LOAD_DWORDX4)
// constexpr int WARP_K = 32;
constexpr int READ_ONCE_LINES = 16; // rows read by one warp per load; (original note: loadx2, each row of 32 elements needs 32 / (4) = 8 threads)
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // elements loaded by one warp per load
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // number of load instructions for the whole tile
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // number of load instructions each warp issues
constexpr int READ_ELEMENT_COUNT = 8; // elements (halfs in the original note) read by each thread per load
int v_lane_headdim_n_idx = lane_id & 3; // which vector chunk within a row this lane is responsible for
int base = (laneid_shfl_2 & 0xc); // smallest row id of the group of 4 this lane's row belongs to
int tail = (laneid_shfl_2 & 0x3); // position within that group of 4
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// row this lane loads; positions 0, 1, 2, 3 within a group map to 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 8x32 handled at a time: threads 0-31 read the first 4x32, threads 32-63 the second 4x32; each 4x32 is read in [0, 16, 0, 16] thread order
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element_k, 2>;
#else
// constexpr int WARP_K = 32;
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 8x32 handled at a time: threads 0-31 read the first 4x32, threads 32-63 the second 4x32; each 4x32 is read in [0, 16, 0, 16] thread order
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_k, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN;
// With two stages the loop starts one block ahead (n_loop = 1, stage 1): the block computed in
// an iteration is n_loop - 1, read from the other stage.
constexpr int N_LOOP_START = (STAGES == 2) ? 1: 0;
int stage_id = (STAGES == 2) ? 1: 0;
for (int n_loop = N_LOOP_START; n_loop < (kHeadDimV / kBlockN); ++n_loop) {
{
// int v_block_buffer_load_global_offset = n_loop*kBlockN;
// Element offset of this wave's WARP_K rows (row stride kHeadDimV) plus the column block n_loop.
int v_block_buffer_load_global_offset = WARP_ID * kHeadDimV * WARP_K + n_loop * kBlockN;
for(int load = 0; load < V_LOAD_REQUESTS; ++load) {
// global->lds, right matrix
// NOTE(review): v_warp_buffer_load_k_id is computed but not used below.
int v_warp_buffer_load_k_id = (load + WARP_ID) % V_LOAD_REQUESTS; // (load / (kBlockN/32));
// int v_warp_buffer_load_n_id = (warp_loop & (kBlockN/32 - 1));
// int v_warp_buffer_load_global_offset = (v_warp_buffer_load_n_id * 32);
int v_warp_buffer_load_lds_offset = /*(v_warp_buffer_load_n_id * 32) +*/ (load * READ_ONCE_COUNT);
// int v_gvoffset = (v_block_buffer_load_global_offset + v_warp_buffer_load_globhalf_tal_offset + /*(k_idx*16*M) + (m_idx*32) +*/ (v_lane_n_idx * 2 + v_lane_k_idx * kHeadDim)) / 2;
// Offsets are halved before being handed to the load helper (presumably converting element
// counts to the unit the helper expects -- TODO confirm against inline_buffer_load_*_lds).
int v_gvoffset_s = (v_block_buffer_load_global_offset/* + v_warp_buffer_load_global_offset*/) / 2;
// Per-lane offset: column chunk plus row * kHeadDimV, with the row clamped to max_seq_kv_offset - 1.
int v_gvoffset_v = ((v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load*READ_ONCE_LINES, max_seq_kv_offset - 1) * kHeadDimV)) / 2;
int v_lds_offset = (v_warp_buffer_load_lds_offset) / 2;
BUFFER_LOAD_FUNC(v_lds + (stage_id)*WARP_K*kBlockN, gV, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
}
// Switch to the other stage: the loads above fill one stage, the reads below consume the other.
if constexpr (STAGES == 2) stage_id ^= 1;
// Move the address computation that precedes ds_read ahead of the wait, so it overlaps with the loads in flight.
// NOTE(review): the two inner loops overwrite precompute_v_lds_offset[vec_idx], so only the value
// for the last (seq_idx, head_dim_idx) pair is kept; this is a single pair when WARP_K == 32 and WARP_N == 32.
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx=0; vec_idx<4; vec_idx++) {
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset)/2) * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Wait for the loads feeding the stage about to be read. With two stages the template argument
// is V_LOAD_REQUESTS (presumably the number of loads allowed to stay in flight, i.e. the ones
// just issued for the other stage -- TODO confirm against buffer_load_lds_dwordx1_wait_nosync).
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
// #pragma unroll
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// The work is split into two halves (min_tile_k = 0 and 1): the ds_read results are permuted
// into vec4 operands and fed to mmac half by half. (The original note describes waiting only
// for the first part of the reads before the first half; the code below waits for lgkmcnt(0) in both halves.)
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// First pack the values needed from the p registers together (p_vgprs layout is simplified for now and ignores the more complex m_idx / k_idx indexing used below).
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)"); // hard-coded for now (the original note mentions 2, but the instruction uses 0); passing it as a compile-time parameter caused a slight performance drop
__builtin_amdgcn_sched_barrier(0);
// Gather the V operand for this half from v_reg entries [0] and [1] (offset by min_tile_k*2).
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
// Multiply-accumulate into the pv_reg block for the tile being consumed.
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
// With two stages the tile consumed in this iteration belongs to the previous column block.
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// Ping-pong buffering between ds (LDS reads) and vgpr: second half, min_tile_k = 1.
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
// Gather the V operand for this half from v_reg entries [2] and [3].
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// Two-stage epilogue: the loop above always computes one block behind its loads, so the last
// column block (loaded in the final iteration) is consumed here without issuing new loads.
if constexpr (STAGES == 2) {
int n_loop = (kHeadDimV / kBlockN) - 1;
stage_id ^= 1;
{
// Move the address computation that precedes ds_read ahead of the wait, so it overlaps with the loads in flight.
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx=0; vec_idx<4; vec_idx++) {
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset)/2) * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Nothing else is in flight for this function at this point, so every branch waits with argument 0.
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<0>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
// #pragma unroll
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// Same two-half scheme as inside the loop: permute the ds_read results into vec4 operands and run mmac for min_tile_k = 0, then 1.
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// First pack the values needed from the p registers together (p_vgprs layout is simplified for now and ignores the more complex m_idx / k_idx indexing used below).
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)"); // hard-coded for now (the original note mentions 2, but the instruction uses 0); passing it as a compile-time parameter caused a slight performance drop
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
// The epilogue accumulates into the last column block directly.
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// Ping-pong buffering between ds (LDS reads) and vgpr: second half, min_tile_k = 1.
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename Element_k, typename ElementAccum>
__forceinline__ __device__ void int8_kvcache_pv_gemm_prefetch_k_3stage(
vec4_uint gV,
vec4_uint gK,
Element_k* v_lds,
Element_k* k_lds,
float scales_v[2][4],
union_vec2_f16x2<Element> p_reg[(WARP_M/32)*(WARP_K/32)][4],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV/kBlockN)*(WARP_M/32)*(kBlockN/32)][4],
int WARP_ID,
int vcache_seqlen_stride,
int max_seq_kv_offset = -1) {
static_assert(kBlockK>=32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert(kBlockM>=WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert(kBlockN==WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
union_vec2_f16x2<Element> v_reg[STAGES*((32*WARP_N)/(32*32))][4];
union_vec2_int8x2<Element_k> v_reg_int8[STAGES*((32*WARP_N)/(32*32))][4];
// union_vec4_f16x2<int8_t> v_reg_int8[STAGES*((WARP_N*kBlockK*bytes_per_Element/2)/(32*32))*2];
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_1 = lane_id >> 1; // 0 ~ 31, 2 个线程读取一行
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 按照上面的读取方式, 第二个 dword 偏移 32 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 32; // 一个 warp 一次读几行数据, loadx2, 每行 32 个元素需要 32 / (4) = 8 个线程
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 一次 load 多少数据
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 整个 workgroup 要发多少 load 指令
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 每个 warp 要发多少 load 指令
constexpr int READ_ELEMENT_COUNT = 16; // 每个线程一次读取几个 half
int v_lane_headdim_n_idx = lane_id & 1; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_1 & 0xfc); // 第几个4线程组的最小id
int tail = (laneid_shfl_1 & 0x3); // 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// 每个线程负责读取第几行数据的 4 个 dword
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element_k, 2>;
#else
constexpr int READ_ONCE_LINES = 8; // 一个 warp 一次读几行数据, loadx2, 每行 32 个元素需要 32 / (4) = 8 个线程
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 一次 load 多少数据
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 整个 workgroup 要发多少 load 指令
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 每个 warp 要发多少 load 指令
constexpr int READ_ELEMENT_COUNT = 4; // 每个线程一次读取几个 half
int v_lane_headdim_n_idx = lane_id & 7; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_3 & 0x4); // 第几个4线程组的最小id
int tail = (laneid_shfl_3 & 0x3); // 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// 每个线程负责读取第几行数据的 4 个 dword
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_k, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN;
constexpr int N_LOOP_START = 0;
// for (int n_loop = N_LOOP_START; n_loop < (kHeadDim / kBlockN); ++n_loop)
int n_loop = 0;
int stage_id = 0;
int precompute_v_lds_offset_int8[4];
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
vec2_Element<Element_k> *v_lds_v2int8 = (vec2_Element<Element_k> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx=0; vec_idx<4; vec_idx++) {
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
precompute_v_lds_offset_int8[vec_idx] = reinterpret_cast<size_t>(v_lds_v2int8) + (stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset);
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
#pragma unroll
for(int ds_idx=0; ds_idx<2; ds_idx++) {
inline_ds_read_b16_no_wait_bytes(precompute_v_lds_offset_int8[vec_idx]+ds_idx*64, v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u16[ds_idx]);
}
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起 (p_vgprs 格式暂且简化, 不考虑下面那个复杂的 m_idx 跟 k_idx)
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(4)"); // 这里暂时先写死是 0,不然int8的kvache结果会不稳定
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 1;
// stage_id = 1;
{
{
stage_id = 0;
// int v_block_buffer_load_global_offset = n_loop*kBlockN;
int v_block_buffer_load_global_offset = WARP_ID * vcache_seqlen_stride * WARP_K + (n_loop + 2/*now, n_loop = 1 rather than 0*/) * kBlockN;
for(int load = 0; load < V_LOAD_REQUESTS; ++load) {
// global->lds, right matrix
int v_warp_buffer_load_k_id = (load + WARP_ID) % V_LOAD_REQUESTS; // (load / (kBlockN/32));
// int v_warp_buffer_load_n_id = (warp_loop & (kBlockN/32 - 1));
// int v_warp_buffer_load_global_offset = (v_warp_buffer_load_n_id * 32);
int v_warp_buffer_load_lds_offset = /*(v_warp_buffer_load_n_id * 32) +*/ (load * READ_ONCE_COUNT);
// int v_gvoffset = (v_block_buffer_load_global_offset + v_warp_buffer_load_globhalf_tal_offset + /*(k_idx*16*M) + (m_idx*32) +*/ (v_lane_n_idx * 2 + v_lane_k_idx * kHeadDim)) / 2;
int v_gvoffset_s = (v_block_buffer_load_global_offset/* + v_warp_buffer_load_global_offset*/) / 4;
int v_gvoffset_v = ((v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load*READ_ONCE_LINES, max_seq_kv_offset - 1) * vcache_seqlen_stride)) / 4;
int v_lds_offset = (v_warp_buffer_load_lds_offset) / 4;
BUFFER_LOAD_FUNC(v_lds + (stage_id)*WARP_K*kBlockN, gV, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
}
stage_id = 1;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
// int precompute_v_lds_offset_int8[4];
// vec2_Element<Element_k> *v_lds_v2int8 = (vec2_Element<Element_k> *)(v_lds);
// // lds -> vgpr use ds_read_m; right matrix
// for(int vec_idx=0; vec_idx<4; vec_idx++) {
// for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
// for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
// precompute_v_lds_offset_int8[vec_idx] = reinterpret_cast<size_t>(v_lds_v2int8) + (stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset);
// }
// }
// }
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
#pragma unroll
for(int ds_idx=0; ds_idx<2; ds_idx++) {
// #pragma unroll
inline_ds_read_b16_no_wait_bytes(precompute_v_lds_offset_int8[vec_idx]+WARP_K*kBlockN+ds_idx*64, v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u16[ds_idx]);
}
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起 (p_vgprs 格式暂且简化, 不考虑下面那个复杂的 m_idx 跟 k_idx)
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(4)"); // 这里暂时先写死是 0,不然int8的kvache结果会不稳定
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 2;
stage_id = 2;
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
// int precompute_v_lds_offset_int8[4];
// vec2_Element<Element_k> *v_lds_v2int8 = (vec2_Element<Element_k> *)(v_lds);
// // lds -> vgpr use ds_read_m; right matrix
// for(int vec_idx=0; vec_idx<4; vec_idx++) {
// for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
// for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
// precompute_v_lds_offset_int8[vec_idx] = reinterpret_cast<size_t>(v_lds_v2int8) + (stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset);
// }
// }
// }
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
#pragma unroll
for(int ds_idx=0; ds_idx<2; ds_idx++) {
inline_ds_read_b16_no_wait_bytes(precompute_v_lds_offset_int8[vec_idx]+ds_idx*64+2*WARP_K*kBlockN, v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u16[ds_idx]);
}
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起 (p_vgprs 格式暂且简化, 不考虑下面那个复杂的 m_idx 跟 k_idx)
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(4)"); // 这里暂时先写死是 0,不然int8的kvache结果会不稳定
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 3;
stage_id = 0;
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
// int precompute_v_lds_offset_int8[4];
// vec2_Element<Element_k> *v_lds_v2int8 = (vec2_Element<Element_k> *)(v_lds);
// // lds -> vgpr use ds_read_m; right matrix
// for(int vec_idx=0; vec_idx<4; vec_idx++) {
// for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
// for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
// precompute_v_lds_offset_int8[vec_idx] = reinterpret_cast<size_t>(v_lds_v2int8) + (stage_id*WARP_K*kBlockN + (seq_idx*32*kBlockN) + head_dim_idx*32*32 + vec_idx*8*32 + v_ds_read_offset);
// }
// }
// }
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0>();
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; left matrix
// int v_lane_head_dim_idx = lane_id % 16;
// int v_lane_seq_idx = lane_id >> 4;
// vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(WARP_K/32); seq_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_N/32); head_dim_idx++) {
#pragma unroll
for(int ds_idx=0; ds_idx<2; ds_idx++) {
// #pragma unroll
// inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
inline_ds_read_b16_no_wait_bytes(precompute_v_lds_offset_int8[vec_idx]+ds_idx*64, v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (head_dim_idx*(WARP_K/32) + seq_idx)][vec_idx].u16[ds_idx]);
}
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32*WARP_N)/(32*32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起 (p_vgprs 格式暂且简化, 不考虑下面那个复杂的 m_idx 跟 k_idx)
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(4)"); // 这里暂时先写死是 0,不然int8的kvache结果会不稳定
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 0],
p_reg[0][(0*2 + min_tile_m)].f16[min_tile_k*2 + 1],
p_reg[0][(1*2 + min_tile_m)].f16[min_tile_k*2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
float temp0 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2];
float temp1 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2];
float temp2 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[min_tile_n]) * scales_v[0][min_tile_k*2+1];
float temp3 = static_cast<float>(v_reg_int8[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].i8[2+min_tile_n]) * scales_v[1][min_tile_k*2+1];
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp0);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp1);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n] = DownCast<float, Element, true>(temp2);
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n] = DownCast<float, Element, true>(temp3);
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n] = vec4_Element<Element>{
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][0 + min_tile_k*2].f16x2[1][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[0][min_tile_n],
v_reg[stage_id*((WARP_N*WARP_K)/(32*32)) + (n_idx* (WARP_K/32) + k_idx)][1 + min_tile_k*2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(WARP_K/32); k_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
int n_loop_idx = n_loop;
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx* (WARP_K/32) + k_idx][/*min_tile_k*2 + */min_tile_n],
pv_reg[n_loop_idx * ((WARP_M/32)*(kBlockN/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// __syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
// Issues the initial global -> LDS buffer loads of int8 KV-cache V tiles for the
// calling wave (destination is LDS per the function / helper names).
//
// Work split that the code below implements:
//   * lane_id = threadIdx.x & 63, i.e. lanes are numbered within a 64-lane wave.
//   * One buffer-load request covers READ_ONCE_LINES rows of 32 elements each;
//     a tile of kBlockN * WARP_K elements therefore takes V_LOAD_REQUESTS requests.
//   * Inside every group of four rows the lane -> row assignment is permuted
//     0,1,2,3 -> 0,2,1,3 (see v_lane_seq_k_idx).
//
// Tiles issued:
//   * STAGES > 1 : tile 0 into LDS stage `stage_id`.
//   * STAGES > 2 : additionally tiles 1 and 2 into stages `stage_id + 1` and
//                  `stage_id + 2`.
//   * otherwise  : no load is issued.
//
// Template parameters referenced by the body: kBlockN, WARP_K, stage_id, Element_k,
// STAGES. kHeadDim, kBlockM, kBlockK, WARP_M, WARP_N and WARP_NUM are not referenced
// here and are kept so existing instantiations keep compiling.
//
// Parameters:
//   gV                   - buffer resource, forwarded unchanged to the load helper.
//   v_lds                - destination base pointer; each wave owns a region of
//                          STAGES * WARP_K * kBlockN elements starting at
//                          WARP_ID * STAGES * WARP_K * kBlockN.
//   WARP_ID              - wave index; selects both the LDS region and the
//                          WARP_K-row slab read from global memory.
//   kcache_seqlen_stride - distance between consecutive rows, in elements.
//   max_seq_kv_offset    - row indices are clamped to at most max_seq_kv_offset - 1.
//                          NOTE(review): with the default of -1 the clamp bound is -2;
//                          callers presumably always pass an explicit bound - confirm.
//
// All offsets handed to the helper are divided by 4
// (presumably element -> dword units for 1-byte elements - TODO confirm).
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element_k, int STAGES>
__forceinline__ __device__ void int8_kvcache_prefetch_v_to_lds(
    vec4_uint gV,
    Element_k* v_lds,
    int WARP_ID,
    int kcache_seqlen_stride,
    int max_seq_kv_offset = -1) {
    const int lane_id = threadIdx.x & 63;

#if defined(USE_BUFFER_LOAD_DWORDX4)
    // 2 lanes per row, 16 elements per lane -> 32 rows per request.
    constexpr int READ_ONCE_LINES = 32;
    constexpr int READ_ELEMENT_COUNT = 16;
    const int v_lane_headdim_n_idx = lane_id & 1;   // which 16-element chunk of the row
    const int lane_row = lane_id >> 1;              // 0 ~ 31
    const int base = lane_row & 0xfc;               // first row of this lane's 4-row group
    auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element_k, 2>;
#else
    // 8 lanes per row, 4 elements per lane -> 8 rows per request.
    constexpr int READ_ONCE_LINES = 8;
    constexpr int READ_ELEMENT_COUNT = 4;
    const int v_lane_headdim_n_idx = lane_id & 7;   // which 4-element chunk of the row
    const int lane_row = lane_id >> 3;              // 0 ~ 7
    const int base = lane_row & 0x4;                // first row of this lane's 4-row group
    auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_k, 2>;
#endif

    constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;                 // elements moved by one request
    constexpr int V_LOAD_REQUESTS = kBlockN * WARP_K / READ_ONCE_COUNT;  // requests per tile
    constexpr int TILE_ELEMS = WARP_K * kBlockN;                         // elements in one LDS stage

    // Swap the two middle rows of each 4-row group: 0,1,2,3 -> 0,2,1,3.
    const int tail = lane_row & 0x3;
    const int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);

    // The per-wave LDS region is STAGES tiles wide for *every* stage. Using any other
    // stride for some stages makes neighbouring waves overwrite each other's tiles
    // whenever STAGES != 3.
    Element_k* const warp_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN;

    // Each wave reads its own slab of WARP_K rows.
    const int warp_global_base = WARP_ID * kcache_seqlen_stride * WARP_K;

    if constexpr (STAGES > 1) {
        // Tile 0 -> stage `stage_id`.
        int v_gvoffset_s = warp_global_base / 4;
        for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
            int v_gvoffset_v = ((v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * kcache_seqlen_stride)) / 4;
            int v_lds_offset = (load * READ_ONCE_COUNT) / 4;
            BUFFER_LOAD_FUNC(warp_lds + stage_id * TILE_ELEMS, gV, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
        }
    }
    // Scheduling fence: keep the compiler from interleaving the stages' loads.
    __builtin_amdgcn_sched_barrier(0);

    if constexpr (STAGES > 2) {
        // Tiles 1 and 2 -> stages `stage_id + 1` and `stage_id + 2`, each followed by
        // a scheduling fence so the issue order stays tile-by-tile.
#pragma unroll
        for (int tile = 1; tile <= 2; ++tile) {
            int v_gvoffset_s = (warp_global_base + tile * kBlockN) / 4;
            for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
                int v_gvoffset_v = ((v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * kcache_seqlen_stride)) / 4;
                int v_lds_offset = (load * READ_ONCE_COUNT) / 4;
                BUFFER_LOAD_FUNC(warp_lds + (stage_id + tile) * TILE_ELEMS, gV, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
            }
            __builtin_amdgcn_sched_barrier(0);
        }
    }
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment