Commit a1eef562 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#include "mla_pv_gemm_utils_tile16x32.h"
template<int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=-1) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert(kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in mla_pv_gemm_prefetch_k must be WARP_N * 32");
union_vec2_f16x2<Element> v_reg[STAGES * PV_K_WARP_COUNT * PV_N_WARP_COUNT][4];
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
constexpr int NEXT_DWORD_OFFSET = 64; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 第二个 dword 偏移 64 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // 一个 warp 每次读几行数据, loadx4, 每个线程读取 8 个 Half, 每行 32 个 Half 需要 32 / 8 = 4 个线程, 所以一个 wave 64 线程会读取 16 行
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 每次 load 多少数据, 16x32
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 一个 warp 一共要发几次读取请求
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 一个 warp 一共要发几次读取请求
constexpr int READ_ELEMENT_COUNT = 8; // 每个线程一次读取几个 Half
int v_lane_headdim_n_idx = lane_id & 3; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_2 & 0xc); // 第几个 4 线程组的最小id
int tail = (laneid_shfl_2 & 0x3); // 4 线程组中的第几个线程
int v_lane_seq_k_idx = laneid_shfl_2; // global -> lds, seqlen 方向的坐标
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 按照线程 [0, 16, 32, 48] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = laneid_shfl_4; // 0-15 read row 0; 16-31 read row 1; 32-47 read row 2; 48-63 read row 3
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + warp_id * STAGES * WARP_K * kBlockN;
int stage_id = (STAGES == 2) ? 1: 0;
constexpr int N_LOOP_START = (STAGES == 2) ? K_LOOP_COUNT - 2: K_LOOP_COUNT - 1;
for (int n_loop = N_LOOP_START; n_loop >= 0; --n_loop) {
int v_block_buffer_load_global_offset = n_loop * kBlockN;
#pragma unroll
for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_lds_offset = load * READ_ONCE_COUNT;
int v_gvoffset_s = v_block_buffer_load_global_offset / 2;
int v_gvoffset_v = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES + warp_id * WARP_K, max_seq_kv_offset - 1) * kvcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + stage_id * WARP_K * kBlockN, v_addr, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
{
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[0],
p_reg[0][0 * 2 + min_tile_m].f16[1],
p_reg[0][0 * 2 + min_tile_m].f16[2],
p_reg[0][0 * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0].f16x2[0][min_tile_n],
v_reg[v_tile_id][0].f16x2[1][min_tile_n],
v_reg[v_tile_id][1].f16x2[0][min_tile_n],
v_reg[v_tile_id][1].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop + 1: n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][1 * 2 + min_tile_m].f16[0],
p_reg[0][1 * 2 + min_tile_m].f16[1],
p_reg[0][1 * 2 + min_tile_m].f16[2],
p_reg[0][1 * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][2].f16x2[0][min_tile_n],
v_reg[v_tile_id][2].f16x2[1][min_tile_n],
v_reg[v_tile_id][3].f16x2[0][min_tile_n],
v_reg[v_tile_id][3].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop + 1: n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
if constexpr (STAGES == 2) {
int n_loop = 0;
stage_id ^= 1;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<0>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
{
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[0],
p_reg[0][0 * 2 + min_tile_m].f16[1],
p_reg[0][0 * 2 + min_tile_m].f16[2],
p_reg[0][0 * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0].f16x2[0][min_tile_n],
v_reg[v_tile_id][0].f16x2[1][min_tile_n],
v_reg[v_tile_id][1].f16x2[0][min_tile_n],
v_reg[v_tile_id][1].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][1 * 2 + min_tile_m].f16[0],
p_reg[0][1 * 2 + min_tile_m].f16[1],
p_reg[0][1 * 2 + min_tile_m].f16[2],
p_reg[0][1 * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][2].f16x2[0][min_tile_n],
v_reg[v_tile_id][2].f16x2[1][min_tile_n],
v_reg[v_tile_id][3].f16x2[0][min_tile_n],
v_reg[v_tile_id][3].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void mla_prefetch_v_to_lds(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=-1) {
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 按照上面的读取方式, 第二个 dword 偏移 32 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // 一个 warp 每次读几行数据, loadx4, 每个线程读取 8 个 Half, 每行 32 个 Half 需要 32 / 8 = 4 个线程, 所以一个 wave 64 线程会读取 16 行
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 每次 load 多少数据, 16x32
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 一个 warp 一共要发几次读取请求
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 一个 warp 一共要发几次读取请求
constexpr int READ_ELEMENT_COUNT = 8; // 每个线程一次读取几个 Half
int v_lane_headdim_n_idx = lane_id & 3; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_2 & 0xc); // 第几个 4 线程组的最小id
int tail = (laneid_shfl_2 & 0x3); // 4 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// global -> lds, seqlen 方向的坐标
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
int n_loop = 0;
if constexpr (STAGES > 1) {
int v_block_buffer_load_global_offset = warp_id * WARP_K * kvcache_seqlen_stride + n_loop * kBlockN;
for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_k_id = (load + warp_id) % V_LOAD_REQUESTS;
int v_warp_buffer_load_lds_offset = load * READ_ONCE_COUNT;
int v_gvoffset_s = v_block_buffer_load_global_offset / 2;
int v_gvoffset_v = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * kvcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + warp_id * STAGES * WARP_K * kBlockN + stage_id * WARP_K * kBlockN, v_addr, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void mla_prefetch_v_to_lds_tile16x32(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=-1) {
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
constexpr int NEXT_DWORD_OFFSET = 64; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 第二个 dword 偏移 64 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // 一个 warp 每次读几行数据, loadx4, 每个线程读取 8 个 Half, 每行 32 个 Half 需要 32 / 8 = 4 个线程, 所以一个 wave 64 线程会读取 16 行
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 每次 load 多少数据, 16x32
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 一个 warp 一共要发几次读取请求
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 一个 warp 一共要发几次读取请求
constexpr int READ_ELEMENT_COUNT = 8; // 每个线程一次读取几个 Half
int v_lane_headdim_n_idx = lane_id & 3; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = (laneid_shfl_2 & 0xc); // 第几个 4 线程组的最小id
int tail = (laneid_shfl_2 & 0x3); // 4 线程组中的第几个线程
int v_lane_seq_k_idx = laneid_shfl_2; // global -> lds, seqlen 方向的坐标
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 按照线程 [0, 16, 32, 48] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = laneid_shfl_4; // 0-15 read row 0; 16-31 read row 1; 32-47 read row 2; 48-63 read row 3
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
int n_loop = (kHeadDim / kBlockN) - 1;
if constexpr (STAGES > 1) {
int v_block_buffer_load_global_offset = n_loop * kBlockN;
#pragma unroll
for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_lds_offset = load * READ_ONCE_COUNT;
int v_gvoffset_s = v_block_buffer_load_global_offset / 2;
int v_gvoffset_v = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES + warp_id * WARP_K, max_seq_kv_offset - 1) * kvcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + warp_id * STAGES * WARP_K * kBlockN + stage_id * WARP_K * kBlockN, v_addr, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "mla_qk_gemm_prefetch_v_qinlds.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_qk_gemm_prefetch_v(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2][4],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
union_vec4_f16x2<Element> k_reg[STAGES * (WARP_N * kBlockK) / (32 * 32) * 2];
// 预先计算一些表达式
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int k_warp_n_id = (warp_id & (WARP_N / WARP_N - 1));
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 64 + (laneid_and_15 >> 1) + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
// 初始化 s
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
int stage_id = 0;
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop++) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK + warp_id * WARP_N * kvcache_seqlen_stride;
int k_lds_stage_offset = (stage_id * WARP_NUM + warp_id) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int __load = (load + warp_id) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = __load & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
inline_buffer_load_dword_lds(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
// 在 wait 之前提前计算这部分偏移量
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = (stage_id * WARP_NUM + warp_id) * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * WARP_N * 17 + n_idx * 32 * 17 + j * 4 + i * 32 + k_ds_read_offset) * 4/*4 bytes per dword*/;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
#pragma unroll
for (int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// 保留第 1 阶段最后一波数据实际的 stage_id
int last_stage_id = stage_id ^ 1;
// 等待第 1 阶段最后一波数据返回做计算
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = warp_id * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + last_stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 2)) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = ((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = ((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
mla_prefetch_v_to_lds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(v_addr, v_lds, warp_id, kvcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "mla_qk_gemm_utils.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int N_WARP_COUNT, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_qk_gemm_prefetch_v_qinlds(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
vec2_Element<Element> q_reg[1 * M_WARP_COUNT * (kBlockK / 32) * 2][4],
vec4_Accum<ElementAccum> s_reg[M_WARP_COUNT * N_WARP_COUNT][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
constexpr int WARP_M = M_WARP_COUNT * 32;
constexpr int WARP_N = N_WARP_COUNT * 32;
constexpr int K_WARP_COUNT = kBlockK / 32;
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
static_assert(STAGES == 1 and "For mla_qk_gemm_prefetch_v_qinlds, only depth 1 is supported");
union_vec4_f16x2<Element> k_reg[STAGES * N_WARP_COUNT * K_WARP_COUNT * 2];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int k_warp_n_id = (warp_id & (kBlockN / WARP_N - 1));
int k_ds_read_offset = k_warp_n_id * N_WARP_COUNT * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
// initial qk gemm results as 0
__builtin_amdgcn_sched_barrier(0);
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (kBlockN / WARP_N) * M_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// k loop across kblockN
int stage_id = 0;
constexpr int K_LOOP_START = 0;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop++) {
{
int k_block_buffer_load_global_offset = k_loop * kBlockK + warp_id * WARP_N * kvcache_seqlen_stride;
int k_lds_stage_offset = (stage_id * WARP_NUM + warp_id) * N_WARP_COUNT * K_WARP_COUNT * (32 * 34);
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int __load = (load + warp_id) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2;
int k_warp_buffer_load_n_id = __load & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32));
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = (min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1) * kvcache_seqlen_stride) / 2 + qk_lane_head_dim_idx;
inline_buffer_load_dword_lds(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
int q_lds_stage_offset = k_loop * (kBlockM / 32) * K_WARP_COUNT * (32 * 17) >> 1;
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < K_WARP_COUNT; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; m_idx++) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * (kBlockM >> 1) * 17 + j * 2 + i * 32 + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 64 + (laneid_and_15 >> 1/*padding*/) + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
int is_useless = (lane_id >> 3) & 1;
lds_offset = (is_useless) ? 0: lds_offset;
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[0 * M_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * M_WARP_COUNT + m_idx) * 2 + i][j]);
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = warp_id * N_WARP_COUNT * K_WARP_COUNT * (32 * 17) + stage_id * WARP_NUM * N_WARP_COUNT * K_WARP_COUNT * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < K_WARP_COUNT; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
buffer_load_lds_dwordx1_wait_nosync<0>();
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < K_WARP_COUNT; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
#pragma unroll
for (int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * N_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * N_WARP_COUNT + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < K_WARP_COUNT; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = 0 * M_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * M_WARP_COUNT + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * N_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * N_WARP_COUNT + n_idx) * 2 + min_tile_n;
s_reg[n_idx * M_WARP_COUNT + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * M_WARP_COUNT + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < K_WARP_COUNT; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = 0 * M_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * M_WARP_COUNT + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * N_WARP_COUNT * K_WARP_COUNT * 2 + (head_dim_idx * N_WARP_COUNT + n_idx) * 2 + min_tile_n;
s_reg[n_idx * M_WARP_COUNT + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id][2 * min_tile_k][0],
q_reg[q_tile_id][2 * min_tile_k][1],
q_reg[q_tile_id][2 * min_tile_k + 1][0],
q_reg[q_tile_id][2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * M_WARP_COUNT + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
} // qk_gemm
#pragma once
#include "mla_qk_gemm_utils_tile16x32.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_qk_gemm_prefetch_v_tile16x32(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];
// 预先计算一些表达式
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 16;
#else // zd
int qk_lane_m_idx = laneid_shfl_4;
int qk_lane_head_dim_idx = laneid_and_15;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 4;
#endif
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (READ_ONCE_LINES * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
int k_warp_n_id = (warp_id & (WARP_N / WARP_N - 1));
// 0,0,32,32,0,0,32,32 | 0,0,32,32,0,0,32,32 | 16,16,48,48,16,16,48,48 | 16,16,48,48,16,16,48,48
// (lane_id & 1) * 16: in seqlen direction, [0,1,0,1,2,3,2,3], odd threads need skip 32 Halfs, 16 dword
// (laneid_and_15 >> 1) * 64: threads 0,1 occupy 4 lines, 4x32 Halfs, 64 dword.... 2,3 and 4,5 and 6,7 is the same
// laneid_and_15 >> 1, padding
// (laneid_shfl_4 & 1) * 8: threads 0,32 is even times of 16, thus 0,32; threads 16,48 is odd times of 16, thus 0,32,16,48; 0->16 need skip 16 Halfs, 8 dword
// (lane_id / 32): 0,0,32,32,0,0,32,32, 0->32, skip 2 Halfs, 1 dword
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 16) + laneid_and_15 * 16 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32) * 4;
// 初始化 s
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
int stage_id = 0;
#ifdef MLA_K_LOAD_32x64_BLOCKS
constexpr int K_LOAD_BLOCKS = 2;
#else
constexpr int K_LOAD_BLOCKS = 1;
#endif
constexpr int K_LOOP_START = (STAGES == 2) ? K_LOAD_BLOCKS: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop += K_LOAD_BLOCKS) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK/*offset in headdim direction*/;
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + stage_id * K_LOAD_BLOCKS) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32));
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
#ifdef MLA_K_LOAD_32x64_BLOCKS
if constexpr (true) {
int k_block_buffer_load_global_offset = (k_loop + 1) * kBlockK/*offset in headdim direction*/;
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + stage_id * K_LOAD_BLOCKS + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32));
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
#endif
// 在 wait 之前提前计算这部分偏移量
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2];
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + stage_id * K_LOAD_BLOCKS) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * WARP_N * 16 + n_idx * 32 * 16 + i * 16 * 16 + k_ds_read_offset;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
#ifdef MLA_K_LOAD_32x64_BLOCKS
buffer_load_lds_dwordx1_wait_nosync<3 * K_LOAD_REQUESTS>();
#else
buffer_load_lds_dwordx1_wait_nosync<1 * K_LOAD_REQUESTS>();
#endif
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - K_LOAD_BLOCKS: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - K_LOAD_BLOCKS: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
#ifdef MLA_K_LOAD_32x64_BLOCKS
{
int precompute_k_lds_offset[2];
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + stage_id * K_LOAD_BLOCKS + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * WARP_N * 16 + n_idx * 32 * 16 + i * 16 * 16 + k_ds_read_offset;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
#endif
}
// 保留第 1 阶段最后一波数据实际的 stage_id
int last_stage_id = stage_id ^ 1;
// 等待第 1 阶段最后一波数据返回做计算
if constexpr (STAGES == 2) {
constexpr int k_loop = kHeadDim / kBlockK;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + last_stage_id * K_LOAD_BLOCKS) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * (WARP_N * 16) + n_idx * (32 * 16) + i * 16 * 16 + k_ds_read_offset;
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 2)) {
#ifdef MLA_K_LOAD_32x64_BLOCKS
buffer_load_lds_dwordx1_wait_nosync<1 * K_LOAD_REQUESTS>();
#else
buffer_load_lds_dwordx1_wait_nosync<0 * K_LOAD_REQUESTS>();
#endif
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - K_LOAD_BLOCKS) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - K_LOAD_BLOCKS) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
#ifdef MLA_K_LOAD_32x64_BLOCKS
{
int precompute_k_lds_offset[2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (warp_id * STAGES * K_LOAD_BLOCKS + last_stage_id * K_LOAD_BLOCKS + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * (WARP_N * 16) + n_idx * (32 * 16) + i * 16 * 16 + k_ds_read_offset;
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
#endif
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
mla_prefetch_v_to_lds_tile16x32<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(v_addr, v_lds, warp_id, kvcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "mla_pv_gemm_utils.h"
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int REUSE_KV_TIMES, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_prefetch_q_to_vgpr(
vec4_uint q_addr,
Element* q_lds,
vec2_Element<Element> q_reg[(kHeadDim / kBlockK) * ((WARP_M * kBlockK) / (32 * 32)) * 2][4],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset = -1) {
constexpr bool MMAC_32x32 = M_MMAC_COUNT > 1;
constexpr int Q_LOAD_REQUESTS = (REUSE_KV_TIMES == 0) ? (kBlockM * kBlockK) / (4 * 32 * WARP_NUM): MMAC_32x32 ? ((REUSE_KV_TIMES + 1) >> 1) << 2 / WARP_NUM: 1/*MHA only need the first token*/;
constexpr int SEQUENCE_READ = MMAC_32x32 ? 2: 1;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = ((lane_id >> 4) & 1) * 2 + ((lane_id >> 4) >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int q_lane_head_dim_idx = lane_id & 15;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int q_ds_read_offset = (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 64 + (laneid_and_15 >> 1) + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
int stage_id = 0;
if constexpr (STAGES > 1) {
int k_loop = 0;
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (offset_v < max_seq_q_offset) {
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2;
offset_v = (offset_v * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
}
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + ((q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32));
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (offset_v < max_seq_q_offset) {
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2;
offset_v = (offset_v * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (STAGES > 1) stage_id ^= 1;
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + j * 2 + i * 32 + q_ds_read_offset;
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i][j]);
}
}
}
}
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + j * 2 + i * 32 + q_ds_read_offset;
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i][j]);
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
}
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_NUM, typename Element, int STAGES, int REUSE_KV_TIMES, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_prefetch_q_to_lds_stage1(
vec4_uint q_addr,
Element* q_lds,
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset = -1) {
constexpr bool MMAC_32x32 = M_MMAC_COUNT > 1 and REUSE_KV_TIMES > 0; /*mtp <= 16 的时候, 也是走这条路径, 只需要加载 16x576 个 Half 到 lds*/
constexpr int Q_LOAD_REQUESTS = (REUSE_KV_TIMES == 0) ? (kBlockM * kBlockK) / (4 * 32 * WARP_NUM): MMAC_32x32 ? ((REUSE_KV_TIMES + 1) >> 1) << 2 / WARP_NUM: 1/*MHA only need the first token*/;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = ((lane_id >> 4) & 1) * 2 + ((lane_id >> 4) >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int q_lane_head_dim_idx = lane_id & 15;
constexpr int K_LOOP_START = 0;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop++) {
// global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = k_loop * (kBlockM / 32) * (kBlockK / 32) * (32 * 17) >> 1;
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (offset_v < max_seq_q_offset) {
int q_warp_buffer_load_lds_offset = (q_warp_buffer_load_m_id >> 3) * (32 * 17) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int lds_offset = q_lds_stage_offset + (q_warp_buffer_load_lds_offset >> 1) + (warp_loop & 7/*padding*/);
offset_v = (offset_v * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
}
template<int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void mla_prefetch_k_to_lds(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
// compute offset pattern
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
// decide buffer load func
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
int stage_id = 0;
int k_loop = 0;
if constexpr (STAGES > 1) {
int k_block_buffer_load_global_offset = k_loop * kBlockK + warp_id * WARP_N * kvcache_seqlen_stride;
int k_lds_stage_offset = (stage_id * WARP_NUM + warp_id) * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int __load = (load + warp_id) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = __load & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "mla_pv_gemm_utils_tile16x32.h"
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT>
__forceinline__ __device__ void mla_prefetch_q_to_vgpr_tile16x32(
vec4_uint q_addr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * ((WARP_M * kBlockK) / (32 * 32)) * 2],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset=-1) {
#if defined(__gfx928__)
constexpr int Q_LOAD_REQUESTS = (kBlockM * kBlockK >> 1/*16x32 tile*/) * M_MMAC_COUNT / (4 * 32 * WARP_NUM);
constexpr int SEQUENCE_READ = M_MMAC_COUNT;
constexpr int READ_ONCE_LINES = 4;
auto BUFFER_LOAD_FUNC = &builtin_buffer_load_dword_lds<Element, float, 1>; // buffer_load_dwordx4 can also be applied if necessary
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = lane_id >> 4;
int q_lane_head_dim_idx = lane_id & 15;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int q_ds_read_offset = laneid_and_15 * 16 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32) * 4;
int stage_id = 0;
if constexpr (STAGES > 1) {
int k_loop = 0;
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32);
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + ((q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32));
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (STAGES > 1) stage_id ^= 1;
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset;
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
q_reg[k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f32 = *(vec4_fp32*)(q_lds_v2fp16 + lds_offset);
}
}
}
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset;
q_reg[((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f32 = *(vec4_fp32*)(q_lds_v2fp16 + lds_offset);
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
#elif defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int lane_id = threadIdx.x & 63;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int q_lds_thread_offset = laneid_and_15 * 16 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32) * 4;
// 尝试用 buffer_load_dwordx4
for (int load_loop = warp_id; load_loop < kHeadDim / kBlockK; load_loop += WARP_NUM) {
// 写到 lds 的地址
int lds_offset = (load_loop * 16 * 32) >> 1;
// warp 的读取地址
int warp_offset = (load_loop * 32) >> 1;
// 精确到线程的地址
int row_idx = lane_id >> 2;
int thread_offset = (min(row_idx, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + (lane_id & 3) * 4/*8 个 half*/;
builtin_buffer_load_dword_lds<Element, float, 4>(q_lds, q_addr, lds_offset, warp_offset, thread_offset);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
#pragma unroll
for (int lds_loop = 0; lds_loop < kHeadDim / kBlockK; ++lds_loop) {
int q_lds_warp_offset = lds_loop * 16 * 32 / 2;
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
q_reg[lds_loop * 2].f32 = *(vec4_fp32*)(q_lds_v2fp16 + q_lds_warp_offset + q_lds_thread_offset);
}
__builtin_amdgcn_sched_barrier(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
// #define MLA_K_LOAD_32x64_BLOCKS
template<int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void mla_prefetch_k_to_lds_tile16x32(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
// 预先计算一些表达式
int lane_id = threadIdx.x & 63; // lane id, 0-63
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 16;
#else // zd
int qk_lane_m_idx = lane_id >> 4;
int qk_lane_head_dim_idx = lane_id & 15;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 4;
#endif
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (READ_ONCE_LINES * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
int stage_id = 0;
int k_loop = 0;
if constexpr (STAGES > 1) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
#ifdef MLA_K_LOAD_32x64_BLOCKS
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#else
int k_lds_stage_offset = (warp_id * STAGES * 1 + stage_id * 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#endif
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32);
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef MLA_K_LOAD_32x64_BLOCKS
if constexpr (STAGES > 1) {
int k_block_buffer_load_global_offset = (k_loop + 1) * kBlockK;
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32);
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
#endif
}
\ No newline at end of file
#pragma once
#include "philox.cuh"
#include "fwd/utils.h"
using namespace flash;
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void mla_apply_mask(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k,
const int col_idx_offset_ = 0) {
const int lane_id = threadIdx.x & 63; // lane id, 0-63
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2;
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx * 8;
if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, int MTP_REGROUP_COUNT, int REUSE_KV_TIMES>
inline __device__ void mla_apply_mask_causal(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int mtp, const int layout) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15) * 2;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2;
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m;
int col_idx_limit_right;
if constexpr (REUSE_KV_TIMES == 0) {
col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q);
} else {
const int row_in_mtp = layout == 0 ? (row_idx % mtp): (row_idx / MTP_REGROUP_COUNT);
col_idx_limit_right = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
}
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx * 8;
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
__device__ inline void mla_thread_reduce_max(const DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary[m_idx * 2].f32[min_tile_m] = -INFINITY;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
}
} else {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m];
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
}
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
__device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary[m_idx * 2].u64 = 0x0;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
}
}
}
#else
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary[m_idx * 2].f32[min_tile_m] = 0;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
#endif
}
} else {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
}
}
}
#else
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m];
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
#endif
}
}
}
template<typename Operator, typename DataType, int M_WARP_COUNT>
__device__ inline void mla_quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; mi++) {
dst[mi] = Allreduce<64>::run(src[mi], op);
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
__device__ inline void mla_reduce_(const DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if constexpr (OpType == 0) { // sum
if constexpr (zero_init == true) {
mla_thread_reduce_sum<true, Operator, 0, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, summary, op);
mla_quad_allreduce_<Operator, DataType1, M_WARP_COUNT>(summary, summary, op);
} else {
mla_thread_reduce_sum<false, Operator, 0, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
mla_quad_allreduce_<Operator, DataType1, M_WARP_COUNT>(summary_cur, summary_cur, op);
}
} else if constexpr (OpType == 1) { // max
if constexpr (zero_init == true) {
mla_thread_reduce_max<true, Operator, 1, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, summary, op);
mla_quad_allreduce_<Operator, DataType1, M_WARP_COUNT>(summary, summary, op);
} else {
mla_thread_reduce_max<false, Operator, 1, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
mla_quad_allreduce_<Operator, DataType1, M_WARP_COUNT>(summary_cur, summary_cur, op);
}
}
}
template<bool zero_init=true, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
__device__ inline void mla_reduce_max(const DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], DataType1 *max , DataType1 *max_cur=nullptr) {
MaxOp<float> max_op;
if constexpr (zero_init == true) {
mla_reduce_<true, MaxOp<float>, 1, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, max, max_op);
} else {
mla_reduce_<false, MaxOp<float>, 1, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, max, max_op, max_cur);
}
}
template<bool zero_init=true, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
__device__ inline void mla_reduce_sum(DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], DataType1 *sum, DataType1 *sum_cur=nullptr) {
SumOp<float> sum_op;
if constexpr (zero_init == true) {
mla_reduce_<true, SumOp<float>, 0, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, sum, sum_op);
} else {
mla_reduce_<false, SumOp<float>, 0, DataType0, DataType1, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(tensor, sum, sum_op, sum_cur);
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename DataType0, typename DataType1, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WARP_COUNT][4], const DataType1 *max, const float scale) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const float max_scaled = (max[mi].f32[min_tile_m] == -INFINITY) ? 0.f : (max[mi].f32[min_tile_m] * (Scale_max ? scale : float(M_LOG2E)));
__float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
__float2 scale_pair = {scale, scale};
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
#else
for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] * scale - max_scaled);
}
#endif
}
}
}
}
}
template<bool Check_inf=false, typename softmaxType, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int N_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT>
inline __device__ void mla_softmax_rescale_o(
vec4_Accum<softmaxType> scores[N_WARP_COUNT * M_WARP_COUNT][4],
vec2_Accum<softmaxType> *scores_max,
vec2_Accum<softmaxType> *scores_sum,
vec4_Accum<softmaxType> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
softmaxType* max_lds,
int warp_id,
float softmax_scale_log2) {
static_assert (std::is_same<softmaxType, float>::value and "For softmax after QK gemm, only float32 is supported!");
// 求当前 32x32 的最大值, 以及和前面计算得到的最大值
vec2_Accum<softmaxType> scores_max_cur[M_WARP_COUNT];
mla_reduce_max</*zero_init=*/false, vec4_Accum<softmaxType>, vec2_Accum<softmaxType>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
int lane_id = threadIdx.x & 63;
constexpr int WARP_M = M_WARP_COUNT * 32;
if constexpr (WARP_NUM > 1) {
static_assert (M_WARP_COUNT == 1);
int dword_offset_base = (lane_id & 15);
if (lane_id < 16) {
if (warp_id == 0) {
for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
max_lds[dword_offset_base + m_loop * 32] = -INFINITY;
}
}
__syncthreads();
for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
__builtin_amdgcn_ds_fmaxf((__attribute__((address_space(3))) float *)max_lds + dword_offset_base + m_loop * 32, scores_max_cur[0].f32[m_loop], 0, 0, false);
}
}
__syncthreads();
for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
scores_max_cur[0].f32[m_loop] = max_lds[dword_offset_base + m_loop * 32];
}
}
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
float scores_max_cur_reg = !Check_inf
? scores_max_cur[mi].f32[min_tile_m]
: (scores_max_cur[mi].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi].f32[min_tile_m]);
float scores_scale = __llvm_exp2_f32((scores_max[mi].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi].f32[min_tile_m] *= scores_scale;
__float2 scores_scale_pair = {scores_scale, scores_scale};
#pragma unroll
for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
#pragma unroll
for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
int loop_id = (pv_n_loop * K_WARP_COUNT + ni) * M_WARP_COUNT + mi;
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
}
#else
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
acc_o[loop_id][min_tile_n * 2 + min_tile_m].f32[vec_idx] *= scores_scale;
}
#endif
}
}
}
}
}
mla_scale_apply_exp2</*zero_init=*/true, vec4_Accum<softmaxType>, vec2_Accum<softmaxType>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(scores, scores_max_cur, softmax_scale_log2);
vec2_Accum<softmaxType> scores_sum_cur[M_WARP_COUNT];
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
scores_sum_cur[mi].u64 = 0x0;
}
mla_reduce_sum</*zero_init=*/true, vec4_Accum<softmaxType>, vec2_Accum<softmaxType>, M_WARP_COUNT, N_WARP_COUNT, M_MMAC_COUNT>(scores, scores_sum_cur);
if constexpr (WARP_NUM > 1) {
// sum 无法用 ds_atomic_add_f32, 因为 non-desterminstic
softmaxType* sum_lds = max_lds + 64;
if(lane_id < 16) {
// 每个 wave 的归一化和写到 lds
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
sum_lds[warp_id * WARP_M + mi * 32 + lane_id * 2] = scores_sum_cur[mi].f32[0];
} else {
*(__float2*)(sum_lds + warp_id * WARP_M + mi * 32 + lane_id * 2) = scores_sum_cur[mi].u64;
}
}
__syncthreads();
// 0 号 wave reduce 其他 wave 的归一化和
if (warp_id == 0) {
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
float tmp = sum_lds[mi * 32 + lane_id * 2];
for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) {
tmp += sum_lds[warp_loop * WARP_M + mi * 32 + lane_id * 2];
}
sum_lds[mi * 32 + lane_id * 2] = tmp;
} else {
__float2 cur_wave_sum = *(__float2*)(sum_lds + mi * 32 + lane_id * 2);
#pragma unroll
for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
#endif
}
*(__float2*)(sum_lds + mi * 32 + lane_id * 2) = cur_wave_sum;
}
}
}
}
__syncthreads();
// 4 个 wave 从 lds 中读取最后 reduce 的归一化和
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
scores_sum_cur[mi].f32[0] = sum_lds[mi * 32 + (lane_id & 15) * 2];
} else {
scores_sum_cur[mi].u64 = *(__float2*)(sum_lds + mi * 32 + (lane_id & 15) * 2);
}
}
}
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
// #######################################################
scores_max[mi].u64 = scores_max_cur[mi].u64;
#else
scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
// #######################################################
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
#endif
}
};
template <int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void mla_convert_pk_type(union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * N_WARP_COUNT][4], union_vec4_fp32 s_reg[M_WARP_COUNT * N_WARP_COUNT][4]) {
#pragma unroll
for (int n_idx = 0; n_idx < N_WARP_COUNT; ++n_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
p_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f32x2[min_tile_k]);
#else
p_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * M_WARP_COUNT + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
p_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * M_WARP_COUNT + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
#endif
}
}
}
}
}
#pragma once
#include "fwd/utils.h"
using namespace flash;
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void mla_apply_mask_tile16x32(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k,
const int col_idx_offset_ = 0) {
const int lane_id = threadIdx.x & 63; // lane id, 0-63
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4);
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx * 4;
if (col_idx >= max_seqlen_k) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, int MTP_REGROUP_COUNT, int REUSE_KV_TIMES>
inline __device__ void mla_apply_mask_causal_tile16x32(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int mtp, const int layout) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + lane_id & 15;
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4);
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
int col_idx_limit_right;
if constexpr (REUSE_KV_TIMES == 0) {
col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q);
} else {
const int row_in_mtp = layout == 0 ? (row_idx % mtp): (row_idx / MTP_REGROUP_COUNT);
col_idx_limit_right = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
}
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx * 4;
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
#pragma once
#include "hip/hip_fp16.h"
#include "hip/hip_bf16.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp8.h"
using bhalf_t = __hip_bfloat16;
using half_t = __half;
using fp8_e4m3 = __hip_fp8_e4m3;
using fp8_e5m2 = __hip_fp8_e5m2;
using BFloat16 = bhalf_t;
using Float16 = half_t;
using Int32 = int;
using Int16 = unsigned short;
using Float32 = float;
using f8_t = uint8_t;
//fp8_e4m3 definitions
struct alignas(1) Float8_e4m3_t{
/// Data container
uint8_t data;
__host__ __device__ Float8_e4m3_t() = default;
__host__ __device__ Float8_e4m3_t(uint8_t value): data(value) {}
};
typedef short __attribute__((ext_vector_type(8))) vec8_bf16;
typedef _Float16 __attribute__((ext_vector_type(8))) vec8_fp16;
using vec4_fp16 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
// using vec8_fp16 = __attribute__((__vector_size__(8 * sizeof(_Float16)))) _Float16;
using vec2_fp16 = __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16;
using vec16_fp16 = __attribute__((__vector_size__(16 * sizeof(_Float16)))) _Float16;
using vec4_bf16 = __attribute__((__vector_size__(4 * sizeof(unsigned short)))) unsigned short;
// using vec8_bf16 = __attribute__((__vector_size__(8 * sizeof(unsigned short)))) unsigned short;
using vec2_bf16 = __attribute__((__vector_size__(2 * sizeof(unsigned short)))) unsigned short;
using vec16_bf16 = __attribute__((__vector_size__(16 * sizeof(unsigned short)))) unsigned short;
using vec4_uint = __attribute__((__vector_size__(4 * sizeof(uint32_t)))) uint32_t;
using vec2_uint = __attribute__((__vector_size__(2 * sizeof(uint32_t)))) uint32_t;
using vec4_int = __attribute__((__vector_size__(4 * sizeof(uint32_t)))) int32_t;
using vec4_fp8 = __attribute__((__vector_size__(4 * sizeof(uint8_t)))) uint8_t;
using vec8_fp8 = __attribute__((__vector_size__(8 * sizeof(uint8_t)))) uint8_t;
using vec2_fp8 = __attribute__((__vector_size__(2 * sizeof(uint8_t)))) uint8_t;
using vec16_fp8 = __attribute__((__vector_size__(16 * sizeof(uint8_t)))) uint8_t;
using vec4_int8 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;
using vec8_int8 = __attribute__((__vector_size__(8 * sizeof(int8_t)))) int8_t;
using vec2_int8 = __attribute__((__vector_size__(2 * sizeof(int8_t)))) int8_t;
using vec16_int8 = __attribute__((__vector_size__(16 * sizeof(int8_t)))) int8_t;
using vec4_int32 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using __builtin_half2 = __attribute__((ext_vector_type(2))) __fp16;
using __float2 = __attribute__((ext_vector_type(2))) float;
using vec4_fp32 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using vec2_fp32 = __attribute__((__vector_size__(2 * sizeof(float)))) float;
union union_vec4_fp32 {
vec4_fp32 f32;
double data[2];
int64_t b64[2];
__float2 u64[2];
vec2_fp32 f32x2[2];
half_t f16[8];
};
union union_vec2_fp32 {
vec2_fp32 f32;
double data;
__float2 u64;
};
union union_vec_fp32 {
float f32[1];
};
union union_vec4_uint {
unsigned long long u64[2]; // 128 bits
uint4 u32;
uint8_t u8[16];
};
union union_vec2_uint {
uint2 u32;
unsigned long long u64;
};
union union_vec8_fp32 {
float f32[8];
vec4_fp32 f32x4[2];
double data[4];
__float2 u64[4];
};
union union_vec4_int32 {
vec4_int32 int32;
int64_t b64[2];
__float2 u64[2];
};
template<typename T>
using vec2_Element =
std::conditional_t
<
std::is_same_v<T, half_t>,
vec2_fp16,
std::conditional_t
<
std::is_same_v<T, bhalf_t>,
vec2_bf16,
std::conditional_t
<
std::is_same_v<T, Float8_e4m3_t> || std::is_same_v<T, int8_t>,
vec2_fp8,
std::conditional_t
<
std::is_same_v<T, int8_t>,
vec2_int8,
vec2_fp32
>
>
>
>;
template<typename T>
using vec4_Element =
std::conditional_t
<
std::is_same_v<T, half_t>,
vec4_fp16,
std::conditional_t
<
std::is_same_v<T, bhalf_t>,
vec4_bf16,
std::conditional_t
<
std::is_same_v<T, Float8_e4m3_t>,
vec4_fp8,
std::conditional_t
<
std::is_same_v<T, int8_t>,
vec4_int8,
void
>
>
>
>;
template<typename T>
using vec8_Element =
std::conditional_t
<
std::is_same_v<T, half_t>,
vec8_fp16,
std::conditional_t
<
std::is_same_v<T, bhalf_t>,
vec8_bf16,
std::conditional_t
<
std::is_same_v<T, Float8_e4m3_t>,
vec8_fp8,
std::conditional_t
<
std::is_same_v<T, int8_t>,
vec8_int8,
void
>
>
>
>;
template<typename Element>
union union_vec2_f16x2 {
vec2_fp32 f32;
double data;
__float2 u64;
vec2_Element<Element> f16x2[2];
vec4_Element<Element> f8x4[2];
vec4_Element<Element> f16x4;
Element f16[4];
Element f8[8];
int32_t i32[2];
};
template<typename Element>
union union_vec2_f8x2 {
// vec2_u16 u16;
float data;
unsigned short u16[2];
vec2_Element<Element> i8x2[2];
vec4_Element<Element> i8x4;
Element i8[4];
};
template<typename Element>
union union_vec4_f16x2 {
vec4_fp32 f32;
double data[2];
__float2 u64[2];
vec2_Element<Element> f16x2[4];
vec4_Element<Element> f16x4[2];
vec8_Element<Element> f16x8;
vec4_int8 f8x4[4];
Element f16[8];
__builtin_half2 b16x2[4];
};
template<typename Element>
union union_vec2_int8x2 {
// vec2_u16 u16;
float data;
unsigned short u16[2];
vec2_Element<Element> i8x2[2];
vec4_Element<Element> i8x4;
Element i8[4];
};
union union_vec16_fp8 {
int8_t i8[16];
vec8_int8 i8x8[2];
vec4_int8 i8x4[4];
vec4_fp32 f32x4;
vec4_int i32x4;
__float2 u64[2];
int32_t i32[4];
};
union union_vec32_fp8 {
int8_t i8[32];
vec8_int8 i8x8[4];
vec4_int8 i8x4[8];
vec4_int i32x4[2];
int32_t i32[8];
};
template<typename T>
using vec4_Accum = std::conditional_t<std::is_same_v<T, float>, union_vec4_fp32, vec4_bf16>;
template<typename T>
using vec2_Accum = std::conditional_t<std::is_same_v<T, float>, union_vec2_fp32, vec4_bf16>;
template<typename T>
using vec_Accum = std::conditional_t<std::is_same_v<T, float>, union_vec_fp32, vec2_bf16>;
template<typename T>
__forceinline__ __device__ vec4_Element<T> make_vec4_f16(T a, T b, T c, T d) {
return {a, b, c, d};
}
template<>
__forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_t b, bhalf_t c, bhalf_t d) {
#ifdef ROCM_5_7
return {a.data, b.data, c.data, d.data};
#else
return {__hip_bfloat16_raw(a).x, __hip_bfloat16_raw(b).x, __hip_bfloat16_raw(c).x, __hip_bfloat16_raw(d).x};
// return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)};
#endif
}
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
#include "numeric_types.h"
namespace flash {
struct ull2 {
unsigned long long x;
unsigned long long y;
};
__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 res;
asm ("v_mul_lo_u32 %0, %2, %3;\n\t"
"v_mul_hi_u32 %1, %2, %3;\n\t"
: "=v"(res.x), "=v"(res.y)
: "v"(a), "v"(b));
return res;
}
__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
__forceinline__ __device__ uint4 philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9;
constexpr unsigned long kPhilox10B = 0xBB67AE85;
uint2 key = reinterpret_cast<uint2&>(seed);
uint4 counter;
ull2 *tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset;
tmp->y = subsequence;
#pragma unroll
for (int i = 0; i < 6; i++) {
counter = philox_single_round(counter, key);
key.x += (kPhilox10A);
key.y += (kPhilox10B);
}
uint4 output = philox_single_round(counter, key);
return output;
}
} // namespace flash
#pragma once
#define PA_FIX_PARTITION 65536
#define MLA_FIX_PARTITION 65536
#define MLA_MAX_SPLITS 1024
#define MLA_MAX_SPLITS_INV 0.0009765625f
#define MLA_FIX_BALANCE_FACTOR 1.5f
template<int MIN_PARTITION_SIZE=128>
__forceinline__ __device__ int splitkv_get_partitionsize_of_fix_numsplits(int actual_seqlen_k, int num_splits) {
float true_partition = max(1.f, actual_seqlen_k / float(num_splits));
int partition_size = 1 << (int(log2f(true_partition - MLA_MAX_SPLITS_INV/*num_splits <= 1024*/)) + 1);
while (num_splits * partition_size > MLA_FIX_BALANCE_FACTOR * actual_seqlen_k and num_splits * (partition_size - MIN_PARTITION_SIZE) > actual_seqlen_k)
partition_size -= MIN_PARTITION_SIZE;
partition_size = max(partition_size, MIN_PARTITION_SIZE);
return partition_size;
}
\ No newline at end of file
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
#include "numeric_types.h"
#include "splitkv.h"
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = Float16; \
return __VA_ARGS__(); \
} else { \
using elem_type = BFloat16; \
return __VA_ARGS__(); \
} \
}()
#define LAYOUT_SWITCH(layout, ...) \
[&] { \
if (layout == 0) { \
constexpr static int Layout = 0; \
return __VA_ARGS__(); \
} else if (layout == 1) { \
constexpr static int Layout = 1; \
return __VA_ARGS__(); \
} \
}()
// #define ElementType_SWITCH(is_bf16, is_e4m3, ...) \
// [&] { \
// if (is_bf16) { \
// using elem_type = BFloat16; \
// return __VA_ARGS__(); \
// } else if (is_e4m3) { \
// using elem_type = Float8_e4m3_t; \
// return __VA_ARGS__(); \
// } else { \
// using elem_type = Float16; \
// return __VA_ARGS__(); \
// } \
// }()
#define ElementType_SWITCH(is_bf16, is_e4m3, ...) \
[&] { \
if (is_bf16) { \
using elem_type = BFloat16; \
return __VA_ARGS__(); \
} else if (is_e4m3) { \
printf("fa bwd does not support fp8 yet");\
} else { \
using elem_type = Float16; \
return __VA_ARGS__(); \
} \
}()
#define HEADDIM_SWITCH(HEADDIMQ, HEADDIMV, ...) \
[&] { \
if (HEADDIMQ <= 32) { \
constexpr static int kHeadDimQ = 32; \
constexpr static int kHeadDimV = 32; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 64) { \
constexpr static int kHeadDimQ = 64; \
constexpr static int kHeadDimV = 64; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 96) { \
constexpr static int kHeadDimQ = 96; \
constexpr static int kHeadDimV = 96; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 128) { \
constexpr static int kHeadDimQ = 128; \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 160) { \
constexpr static int kHeadDimQ = 160; \
constexpr static int kHeadDimV = 160; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 192) { \
constexpr static int kHeadDimQ = 192; \
if (HEADDIMV <= 128) { \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMV <= 192) { \
constexpr static int kHeadDimV = 192; \
return __VA_ARGS__(); \
} \
} else if (HEADDIMQ <= 224) { \
constexpr static int kHeadDimQ = 224; \
constexpr static int kHeadDimV = 224; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 256) { \
constexpr static int kHeadDimQ = 256; \
constexpr static int kHeadDimV = 256; \
return __VA_ARGS__(); \
} \
}()
#define ALL_HEADDIM_SWITCH(HEADDIMQ, HEADDIMV, ...) \
[&] { \
if (HEADDIMQ <= 32) { \
constexpr static int kHeadDimQ = 32; \
constexpr static int kHeadDimV = 32; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 64) { \
constexpr static int kHeadDimQ = 64; \
constexpr static int kHeadDimV = 64; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 96) { \
constexpr static int kHeadDimQ = 96; \
constexpr static int kHeadDimV = 96; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 128) { \
constexpr static int kHeadDimQ = 128; \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 160) { \
constexpr static int kHeadDimQ = 160; \
constexpr static int kHeadDimV = 160; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 192) { \
constexpr static int kHeadDimQ = 192; \
if (HEADDIMV <= 128) { \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMV <= 192) { \
constexpr static int kHeadDimV = 192; \
return __VA_ARGS__(); \
} \
} else if (HEADDIMQ <= 224) { \
constexpr static int kHeadDimQ = 224; \
constexpr static int kHeadDimV = 224; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 256) { \
constexpr static int kHeadDimQ = 256; \
constexpr static int kHeadDimV = 256; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 512) { \
constexpr static int kHeadDimQ = 512; \
constexpr static int kHeadDimV = 512; \
return __VA_ARGS__(); \
} \
}()
#define WARP_ID_SWITCH(warp_id, ...) \
[&] { \
if (warp_id == 0) { \
constexpr int WARP_ID = 0; \
return __VA_ARGS__(); \
} else if (warp_id == 1) { \
constexpr int WARP_ID = 1; \
return __VA_ARGS__(); \
} else if (warp_id == 2) { \
constexpr int WARP_ID = 2; \
return __VA_ARGS__(); \
} else if (warp_id == 3) { \
constexpr int WARP_ID = 3; \
return __VA_ARGS__(); \
} \
}()
#define CU_SWITCH(cu_count, ...) \
[&] { \
if (cu_count == 120) { \
constexpr int CU_COUNT = 120; \
return __VA_ARGS__(); \
} else if (cu_count == 128) { \
constexpr int CU_COUNT = 128; \
return __VA_ARGS__(); \
} else if (cu_count == 88) { \
constexpr int CU_COUNT = 88; \
return __VA_ARGS__(); \
} else if (cu_count == 80) { \
constexpr int CU_COUNT = 80; \
return __VA_ARGS__(); \
} \
}()
#define M_MMAC_COUNT_SWITCH(COND, M_MMAC_COUNT, ...) \
[&] { \
if (COND) { \
constexpr static int M_MMAC_COUNT = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int M_MMAC_COUNT = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH(seqlen_q , ...) \
[&] { \
if (seqlen_q == 16) { \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
} else if (seqlen_q == 13) { \
constexpr static int REUSE_KV_TIMES = 13; \
return __VA_ARGS__(); \
} else if (seqlen_q == 8) { \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
} else if (seqlen_q == 7) { \
constexpr static int REUSE_KV_TIMES = 7; \
return __VA_ARGS__(); \
} else if (seqlen_q == 4) { \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (seqlen_q == 3) { \
constexpr static int REUSE_KV_TIMES = 3; \
return __VA_ARGS__(); \
} else if (seqlen_q == 2) { \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else if (seqlen_q == 1) { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} else if (seqlen_q == 29) { \
constexpr static int REUSE_KV_TIMES = 29; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 0; \
return __VA_ARGS__(); \
} \
}()
#define PA_GQA_REGROUP_SWITCH(ngroups, ...) \
[&] { \
if (ngroups == 8) { \
int GQA_REGROUP = 8; \
return __VA_ARGS__(); \
} else if (ngroups == 4) { \
int GQA_REGROUP = 4; \
return __VA_ARGS__(); \
} else if (ngroups == 2) { \
int GQA_REGROUP = 2; \
return __VA_ARGS__(); \
} else if (ngroups == 9) { \
int GQA_REGROUP = 9; \
return __VA_ARGS__(); \
} else if (ngroups == 7) { \
int GQA_REGROUP = 7; \
return __VA_ARGS__(); \
} else if (ngroups == 5) { \
int GQA_REGROUP = 5; \
return __VA_ARGS__(); \
} else if (ngroups == 3) { \
int GQA_REGROUP = 3; \
return __VA_ARGS__(); \
} else if (ngroups == 29) { \
int GQA_REGROUP = 29; \
return __VA_ARGS__(); \
} else { \
int GQA_REGROUP = ngroups; \
return __VA_ARGS__(); \
} \
}()
#define PA_MTP_REUSEKV_SWITCH(pa_mtp , ...) \
[&] { \
if (pa_mtp == 8) { \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
} else if (pa_mtp == 16) { \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
} else if (pa_mtp == 20) { \
constexpr static int REUSE_KV_TIMES = 20; \
return __VA_ARGS__(); \
} else if (pa_mtp == 24) { \
constexpr static int REUSE_KV_TIMES = 24; \
return __VA_ARGS__(); \
} else if (pa_mtp == 28) { \
constexpr static int REUSE_KV_TIMES = 28; \
return __VA_ARGS__(); \
} else if (pa_mtp == 32) { \
constexpr static int REUSE_KV_TIMES = 32; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 0; \
return __VA_ARGS__(); \
} \
}()
#define MLA_REUSEKV_SWITCH(seqlen_q , ...) \
[&] { \
if (seqlen_q == 128) { \
constexpr static int REUSE_KV_TIMES = 128; \
return __VA_ARGS__(); \
} else if (seqlen_q == 64) { \
constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \
} else if (seqlen_q == 32) { \
constexpr static int REUSE_KV_TIMES = 32; \
return __VA_ARGS__(); \
} else if (seqlen_q == 16) { \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
} else if (seqlen_q == 12) { \
constexpr static int REUSE_KV_TIMES = 12; \
return __VA_ARGS__(); \
} else if (seqlen_q == 8) { \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
} else if (seqlen_q == 4) { \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (seqlen_q == 2) { \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else if (seqlen_q == 1) { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define MLA_MTP_REUSEKV_SWITCH(seqlen_q , ...) \
[&] { \
if (seqlen_q == 32) { \
constexpr static int REUSE_KV_TIMES = 32; \
return __VA_ARGS__(); \
} else if (seqlen_q == 24) { \
constexpr static int REUSE_KV_TIMES = 24; \
return __VA_ARGS__(); \
} else if (seqlen_q == 16) { \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
} \
}()
#define PERMUTE_DWORD_SWITCH(COND, DWORD_PER_TX, ...) \
[&] { \
if (COND) { \
constexpr static int DWORD_PER_TX = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int DWORD_PER_TX = 1; \
return __VA_ARGS__(); \
} \
}()
#define PERMUTE_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM < 256) { \
constexpr static int kHeadDim = 0; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
#define PA_HEADDIM_SWITCH(HEADDIMQ, HEADDIMV, ...) \
[&] { \
if (HEADDIMQ <= 32) { \
constexpr static int kHeadDimQ = 32; \
constexpr static int kHeadDimV = 32; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 64) { \
constexpr static int kHeadDimQ = 64; \
constexpr static int kHeadDimV = 64; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 96) { \
constexpr static int kHeadDimQ = 96; \
constexpr static int kHeadDimV = 96; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 128) { \
constexpr static int kHeadDimQ = 128; \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 160) { \
constexpr static int kHeadDimQ = 160; \
constexpr static int kHeadDimV = 160; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 192) { \
constexpr static int kHeadDimQ = 192; \
if (HEADDIMV <= 128) { \
constexpr static int kHeadDimV = 128; \
return __VA_ARGS__(); \
} else if (HEADDIMV <= 192) { \
constexpr static int kHeadDimV = 192; \
return __VA_ARGS__(); \
} \
} else if (HEADDIMQ <= 224) { \
constexpr static int kHeadDimQ = 224; \
constexpr static int kHeadDimV = 224; \
return __VA_ARGS__(); \
} else if (HEADDIMQ <= 256) { \
constexpr static int kHeadDimQ = 256; \
constexpr static int kHeadDimV = 256; \
return __VA_ARGS__(); \
} else if (HEADDIMQ == 576) { \
constexpr static int kHeadDimQ = 576; \
constexpr static int kHeadDimV = 512; \
return __VA_ARGS__(); \
} \
}()
#define PA_PAGEBLOCKSIZE_SWITCH(page_block_size, ...) \
[&] { \
if (page_block_size % 32 == 0 and page_block_size % 64 != 0) { \
constexpr static int kBlockN = 32; \
return __VA_ARGS__(); \
} else if (page_block_size % 64 == 0 and page_block_size % 128 != 0) { \
constexpr static int kBlockN = 64; \
return __VA_ARGS__(); \
} else if (page_block_size % 128 == 0) { \
constexpr static int kBlockN = 128; \
return __VA_ARGS__(); \
} \
}()
#define MLA_PARTITION_SIZE_SWITCH(partition_size, num_splits, ...) \
[&] { \
if (partition_size == 128 and num_splits > 1) { \
constexpr static int Partition_Size = 1; \
return __VA_ARGS__(); \
} else if (partition_size == 256 and num_splits > 1) { \
constexpr static int Partition_Size = 2; \
return __VA_ARGS__(); \
} else if (partition_size == 512 and num_splits > 1) { \
constexpr static int Partition_Size = 4; \
return __VA_ARGS__(); \
} else if (partition_size == 1024 and num_splits > 1) { \
constexpr static int Partition_Size = 8; \
return __VA_ARGS__(); \
} else if (partition_size == MLA_FIX_PARTITION and num_splits > 1) { \
constexpr static int Partition_Size = MLA_FIX_PARTITION; \
return __VA_ARGS__(); \
} else { \
constexpr static int Partition_Size = 0; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
#ifndef WAIT_H
#define WAIT_H
#define USE_PINGPANG_BUFFER
namespace flash {
__forceinline__ __device__ void wait_all_warp_arrived() {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n");
__builtin_amdgcn_sched_barrier(0);
}
template<bool Sync>
__forceinline__ __device__ void wait_all_buffer_data_arrived() {
__builtin_amdgcn_sched_barrier(0);
if constexpr (Sync) {
asm volatile("s_waitcnt vmcnt(0)\n\ts_barrier\n");
} else {
asm volatile("s_waitcnt vmcnt(0)\n");
}
__builtin_amdgcn_sched_barrier(0);
}
template<bool Sync>
__forceinline__ __device__ void wait_buffer_data_arrived(const int wait_count=0) {
__builtin_amdgcn_sched_barrier(0);
if constexpr (Sync) {
asm volatile("s_waitcnt vmcnt(%0)\n\ts_barrier\n":: "n"(wait_count));
} else {
asm volatile("s_waitcnt vmcnt(%0)\n":: "n"(wait_count));
}
__builtin_amdgcn_sched_barrier(0);
}
template<bool Sync>
__forceinline__ __device__ void wait_lds_data_arrived(const int wait_count=0) {
__builtin_amdgcn_sched_barrier(0);
if constexpr (Sync) {
asm volatile("s_waitcnt lgkmcnt(%0)\n\ts_barrier\n":: "n"(wait_count));
} else {
asm volatile("s_waitcnt lgkmcnt(%0)\n":: "n"(wait_count));
}
__builtin_amdgcn_sched_barrier(0);
}
} // namespace flash
template<const int COUNT>
__forceinline__ __device__ void buffer_load_lds_dwordx1_wait() {
asm volatile(
"s_waitcnt vmcnt(%0)\n\t"
"s_barrier\n"
:: "B"(COUNT)
:);
}
template<const int COUNT>
__forceinline__ __device__ void buffer_load_lds_dwordx1_wait_nosync() {
asm volatile(
"s_waitcnt vmcnt(%0)\n\t"
:: "B"(COUNT)
:);
}
template<int BLOCK_M, int BLOCK_N, int BLOCK_K>
inline __device__ void buffer_load_lds_dwordx1_wait() {
asm volatile("s_waitcnt vmcnt(0) \n\t"
"s_barrier");
}
__forceinline__ __device__ void s_barrier() {
asm volatile("s_barrier\n");
}
#define lgkmcnt_wait(X)\
__builtin_amdgcn_sched_barrier(0);\
asm volatile("s_waitcnt lgkmcnt(%0)": : "I"(X));\
__builtin_amdgcn_sched_barrier(0);
#define vmcnt_wait(X)\
__builtin_amdgcn_sched_barrier(0);\
asm volatile(\
"s_waitcnt vmcnt(%0)\n\t"\
"s_barrier\n"\
:: "I"(X)\
:);\
__builtin_amdgcn_sched_barrier(0);
#define vmcnt_wait_nosync(X)\
__builtin_amdgcn_sched_barrier(0);\
asm volatile(\
"s_waitcnt vmcnt(%0)\n\t"\
:: "I"(X)\
:);\
__builtin_amdgcn_sched_barrier(0);
#endif
This source diff could not be displayed because it is too large. You can view the blob instead.
/******************************************************************************************
* Copyright (c) 2025, Baohui.Fang, Yushun.Zhang, Chang.liu, Wenjian.Zhang, Jianbang.Xu
*****************************************************************************************/
#pragma once
#include "flash_fwd_launch_template_pa.h"
template<typename Kernel_traits, bool Is_causal, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mla_kernel(Params params) {
flash::compute_attn_splitkv_mla<Kernel_traits, Is_causal, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mla_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_mla_gfx938<Kernel_traits, Is_causal, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_fp8_mla_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_fp8_mla_gfx938<Kernel_traits, Is_causal, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, const bool Tail, typename Params>
void run_mla_splitkv_reduce(Params &params, hipStream_t stream) {
static_assert (Kernel_traits::kHeadDimV == 512 and "run_mla_splitkv_reduce only support splitkv for hdimv == 512");
using Element = typename Kernel_traits::Element;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
// rearrange narrow-storation params
Flash_fwd_mla_reduce_params reduce_params;
reduce_params.softmax_lse_ptr = params.softmax_lse_ptr;
reduce_params.oaccum_ptr = params.oaccum_ptr;
reduce_params.o_ptr = params.o_ptr;
reduce_params.cu_seqlens_k = params.cu_seqlens_k;
reduce_params.num_splits = params.num_splits;
reduce_params.partition_size = params.partition_size;
reduce_params.h = params.h;
reduce_params.seqlen_q = params.seqlen_q;
reduce_params.layout = params.layout;
reduce_params.topk_length = params.topk_length;
reduce_params.attn_sink = params.attn_sink;
reduce_params.extra_topk_length = params.extra_topk_length;
reduce_params.topk = params.topk;
reduce_params.extra_topk = params.extra_topk;
// reduce num_splits x [batch_size, num_head_q, seqlen_q, head_dim] output
if (params.num_splits > 1) {
dim3 block(256);
dim3 grid(params.b * params.h * params.seqlen_q, 4);
constexpr int MAX_NUM_SPLITS = 64;
if (params.num_splits > MAX_NUM_SPLITS) {
printf("\x1b[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel\033[0m\n", params.num_splits, MAX_NUM_SPLITS);
return;
}
if (params.num_splits == 2) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 2, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 4) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 4, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 8) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 8, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 16) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 16, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 32) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 32, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 64) {
flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(reduce_params);
} else {
printf("\x1b[31mnum_splits %d is not supported yet, and thus won't execute the kernel\033[0m\n", params.num_splits);
return;
}
} // if (params.num_splits > 1)
}
template<typename Kernel_traits>
void run_flash_splitkv_mla(Flash_fwd_mla_params &params, hipStream_t stream) {
// decide whether mls can be applied
bool mls_enabled = getArch() >= 938 and (std::getenv("MLA_NO_MLS") == nullptr);
// judge if mtp > 1, attention the causal mask ?
bool use_mtp = params.mtp > 1; // bool(params.seqlen_q > 1 and !params.seqlenq_ngroups_swapped);
// judge whether run with 16x32 tile
bool use_tile_16x32 = params.seqlen_q <= 16/*16x32 tile*/ and std::getenv("MLA_USE_TILE32X32") == nullptr/*env control*/;
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
/*every 2 in seqlen dimension requires 1024 KB lds, and thus, max lds should be limited within 64 * 2 = 128*/
const size_t smem_for_acc = int((params.seqlen_q + 1) / 2) * 2/*>= 的偶数*/ * WARP_NUM * Kernel_traits::kBlockK * sizeof(float);
const size_t _q_smem_size = Kernel_traits::STAGES == 2 ? Kernel_traits::q_smem_size: Kernel_traits::q_smem_size * (Kernel_traits::kHeadDim / Kernel_traits::kBlockK/*576需要加载几次*/) / 2; // 除以 2 是因为只需要用 16x32 的 lds, 节约用量
const size_t q_smem_size = use_tile_16x32 ? _q_smem_size / 34 * 32: _q_smem_size;
const size_t k_smem_size = use_tile_16x32 ? Kernel_traits::k_smem_size / 34 * 32 * WARP_NUM/*16x32 tile no padding*/: Kernel_traits::k_smem_size * WARP_NUM/*32x32 tile use padding 32 -> 34*/;
const size_t v_smem_size = Kernel_traits::v_smem_size * WARP_NUM;
const size_t smem_for_gemm = Kernel_traits::STAGES == 2 ? std::max(q_smem_size, std::max(k_smem_size, v_smem_size)): q_smem_size + std::max(k_smem_size, v_smem_size);
const size_t required_smem_size = std::max(smem_for_acc, smem_for_gemm + smem_for_max);
const size_t smem_size = mls_enabled ? (params.b >= params.cu_count ? 64 * 1024: 32 * 1024): (use_tile_16x32 ? 32 * 1024: required_smem_size);
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// acquire kernel fuction
void (*kernel)(Flash_fwd_mla_params);
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld\nuse_tile_16x32: %d\n",
smem_for_max, smem_for_acc, q_smem_size, k_smem_size, v_smem_size, smem_for_gemm, required_smem_size, smem_size, use_tile_16x32);
printf("dispatch grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
}
// decide which kernel function
if (use_tile_16x32) {
// mtp is not considered yet for 16x32 tile
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
constexpr int M_MMAC_COUNT = 1; // only need to compute 16x576 @ 576xseqlenk
grid.x = num_m_block * HEADDIM_V_SPLIT;
MLA_PARTITION_SIZE_SWITCH(params.partition_size, params.num_splits, [&] {
if (use_mtp) {
if (!params.seqlenq_ngroups_swapped) {
kernel = mls_enabled
? &flash_fwd_splitkv_mla_gfx938_kernel<Kernel_traits, true/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>
: &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>;
} else {
MLA_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = mls_enabled
? &flash_fwd_splitkv_mla_gfx938_kernel<Kernel_traits, true/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>
: &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>;
});
}
} else {
MLA_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = mls_enabled
? &flash_fwd_splitkv_mla_gfx938_kernel<Kernel_traits, false/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>
: &flash_fwd_splitkv_mla_kernel<Kernel_traits, false/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>;
});
}
});
} else {
// need to compute 32x576 @ 576xseqlenk, split-D is used to reduce vgpr spill
constexpr int HEADDIM_V_SPLIT = 2;
grid.x = num_m_block * HEADDIM_V_SPLIT;
// fixed num_splits for mla
if (params.partition_size == MLA_FIX_PARTITION and params.num_splits > 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 1 or use_mtp, M_MMAC_COUNT, [&] {
if (use_mtp) { // MTP > 1
if (!params.seqlenq_ngroups_swapped) {
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, true/*Split*/, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, MLA_FIX_PARTITION>;
} else {
MLA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&]{
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, true/*Split*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, MLA_FIX_PARTITION>;
});
}
} else {
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, false/*Is_causal*/, true/*Split*/, M_MMAC_COUNT, 0/*REUSE_KV_TIMES*/, HEADDIM_V_SPLIT, MLA_FIX_PARTITION>;
}
});
} else {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 1 or use_mtp, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
if (use_mtp) { // MTP > 1
if (!params.seqlenq_ngroups_swapped) {
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, Split, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, 0>;
} else {
MLA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&]{
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, true/*Is_causal*/, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, 0>;
});
}
} else {
kernel = &flash_fwd_splitkv_mla_kernel<Kernel_traits, false/*Is_causal*/, Split, M_MMAC_COUNT, 0/*REUSE_KV_TIMES*/, HEADDIM_V_SPLIT, 0>;
}
});
});
}
}
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce
// run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream);
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_splitkv_dispatch(Flash_fwd_mla_params &params, hipStream_t stream) {
// 是否编译多个 page block size, 代码会膨胀
#ifdef MLA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
run_flash_splitkv_mla<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_mla<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
/*以前的 STAGES = 1, 适用于 "不做 splitD, seqlen<=16, 把 Q 不读到寄存器, 读到 lds, 且只塞 16x576 个 Half 到 lds" 的情况, 当时情况可行, 但后续性能被 16x32 tile 取代, 可以作为紧急时期的备选方案*/
#endif
}
template<typename Kernel_traits>
void run_flash_splitkv_fp8_mla(Flash_fwd_mla_params &params, hipStream_t stream) {
// judge if mtp > 1, attention the causal mask ?
bool use_mtp = params.mtp > 1; // bool(params.seqlen_q > 1 and !params.seqlenq_ngroups_swapped);
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// acquire kernel fuction
void (*kernel)(Flash_fwd_mla_params);
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// decide shared memory size
const size_t smem_size = 32 * 1024;
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_size: %ld\n", smem_size);
printf("dispatch grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
}
// decide which kernel function
// mtp is not considered yet for 16x32 tile
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
constexpr int M_MMAC_COUNT = 1; // only need to compute 16x576 @ 576xseqlenk
grid.x = num_m_block * HEADDIM_V_SPLIT;
MLA_PARTITION_SIZE_SWITCH(params.partition_size, params.num_splits, [&] {
if (use_mtp) {
kernel = &flash_fwd_splitkv_fp8_mla_gfx938_kernel<Kernel_traits, true/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>;
} else {
// only tp8 is supported yet
kernel = &flash_fwd_splitkv_fp8_mla_gfx938_kernel<Kernel_traits, false/*Is_causal*/, bool(Partition_Size > 0), M_MMAC_COUNT, 16, HEADDIM_V_SPLIT, Partition_Size, Flash_fwd_mla_params>;
}
});
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce
// run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream);
}
template<typename T, int Headdim, int HeaddimV>
void run_fp8_mla_fwd_splitkv_dispatch(Flash_fwd_mla_params &params, hipStream_t stream) {
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fp8_mla<Flash_fwd_kernel_traits<Headdim, HeaddimV, 16, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
}
template<typename T, int Headdim, int HeaddimV>
void run_fp8_mla_convert_q_to_fp8_dispatch(Flash_fwd_mla_params &params, hipStream_t stream) {
BOOL_SWITCH(params.total_blocks > 256, is_persistent, [&]{
dim3 grid(is_persistent ? 512: params.total_blocks);
dim3 block(64, 1, 1);
flash_mla_convert_query_to_fp8_kernel<T, fp8_e4m3, is_persistent><<<grid, block, 8192, stream>>>(
reinterpret_cast<fp8_e4m3*>(params.o_ptr),
reinterpret_cast<T*>(params.qv_ptr),
reinterpret_cast<T*>(params.q_ptr),
params.total_blocks,
params.o_head_stride,
params.qv_head_stride,
params.q_head_stride,
params.qv_row_stride,
params.q_row_stride,
params.h
);
});
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_prefix_prefill_dispatch(Flash_fwd_mla_params &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("MLA_PREFILL_NO_MLS") == nullptr) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dim3 dimGrid;
dimBlock.x = min((kBlockM) / (WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
using Kernel_traits = Flash_fwd_kernel_traits<Headdim, HeaddimV, kBlockM, kBlockN, 32/* kBlockK */, WARP_M, 64/* WARP_N */, 2/* STAGES */, false, false, T>;
constexpr int REUSE_KV = 1;
constexpr bool Is_dropout = false;
if (params.is_causal) {
dimGrid.x = (params.seqlen_q + 2 * kBlockM - 1) / (2 * kBlockM);
} else {
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
}
dimGrid.y = (params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
const bool is_mtp = (params.mtp != 0) ? true : false;
if (is_mtp == false) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_prefill_kernel_gfx938<Kernel_traits, true/*Is_training*/, Is_dropout, true/* Is_prefix */,Is_causal, IsEvenMNConst, true/*Is_even_K*/, false/*Return_softmax*/, false/* Is_MTP */, 0, Flash_fwd_mla_params>
<<<dimGrid, dimBlock, 32 * 1024, stream>>>(params); // 暂时保留MTP模板参数,实测无误后删除
});
} else {
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// flash::flash_fwd_mla_decode_kernel_gfx938<Kernel_traits, true/*Is_training*/, Is_dropout, true/* Is_prefix */,Is_causal, IsEvenMNConst, true/*Is_even_K*/, false/*Return_softmax*/, true/* Is_MTP */, 0, Flash_fwd_mla_params>
// <<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
// });
}
} else {
if (params.b * params.h >= params.cu_count) {
constexpr int kBlockM = 32; // vgpr spill 280+ when WARP_M = 32
constexpr int kBlockN = 128;
constexpr int parallel = 2;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // todo: streamkv
params.q_blocks = (params.q_blocks + 1) / parallel;
dim3 grid(parallel, params.h, params.b); // todo: regroup qheads into seqlen_q and dispatch less blocks
dim3 block(256, 1, 1);
flash_fwd_mla_prefix_prefill_kernel<Headdim, HeaddimV, kBlockM, kBlockN, true/*Is_prefix*/, false/*Is_causal*/, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
} else {
constexpr int kBlockM = 64;
constexpr int kBlockN = 128;
params.q_blocks = (params.seqlen_q + 2 * kBlockM - 1) / (2 * kBlockM);
params.total_blocks = params.b * params.h * params.q_blocks;
dim3 grid(params.cu_count); // (params.q_blocks, params.h, params.b) when no fix, corresponding to #elif 0
dim3 block(512, 1, 1);
flash_fwd_mla_prefix_prefill_fix_kernel<Headdim, HeaddimV, kBlockM, kBlockN, true/*Is_prefix*/, false/*Is_causal*/, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
}
}
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_dispatch(Flash_fwd_mla_params &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("MLA_DP_DECODE_NO_MLS") == nullptr/* and params.b >= 16 */) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dim3 dimGrid;
dimBlock.x = min((kBlockM) / (WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
using Kernel_traits = Flash_fwd_kernel_traits<Headdim, HeaddimV, kBlockM, kBlockN, 32/* kBlockK */, WARP_M, 64/* WARP_N */, 2/* STAGES */, false, false, T>;
constexpr int REUSE_KV = 1;
constexpr bool Is_dropout = false;
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
dimGrid.y = 1;//(params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// flash::flash_fwd_mla_decode_kernel_gfx938<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params>
// <<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
// });
});
} else {
if (params.b * params.h >= params.cu_count / 2) {
constexpr int kBlockM = 32; // vgpr spill 280+ when WARP_M = 32
constexpr int kBlockN = 128;
constexpr int parallel = 2;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // todo: streamkv
params.q_blocks = (params.q_blocks + 1) / parallel;
dim3 grid(parallel, params.h, params.b); // todo: regroup qheads into seqlen_q and dispatch less blocks
dim3 block(256, 1, 1);
BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
flash_fwd_mla_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
});
} else {
constexpr int kBlockM = 64;
constexpr int kBlockN = 128;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // used in kernels
params.total_blocks = params.b * params.h * params.q_blocks;
// dim3 grid(params.q_blocks, params.h, params.b);
dim3 grid(params.cu_count);
dim3 block(512, 1, 1);
BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
flash_fwd_mla_fix_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
});
}
}
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_dispatch_dsa(Flash_fwd_mla_params_dsa &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("MLA_DP_DECODE_NO_MLS") == nullptr/* and params.b >= 16 */) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dim3 dimGrid;
dimBlock.x = min((kBlockM) / (WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
using Kernel_traits = Flash_fwd_kernel_traits<Headdim, HeaddimV, kBlockM, kBlockN, 32/* kBlockK */, WARP_M, 64/* WARP_N */, 2/* STAGES */, false, false, T>;
constexpr int REUSE_KV = 1;
constexpr bool Is_dropout = false;
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
dimGrid.y = 1;//(params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
if(params.seqlen_q == 128){
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_decode<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
});
});
}
else{
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
});
});
}
} else {
if (params.b * params.h >= params.cu_count / 2) {
constexpr int kBlockM = 32; // vgpr spill 280+ when WARP_M = 32
constexpr int kBlockN = 128;
constexpr int parallel = 2;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // todo: streamkv
params.q_blocks = (params.q_blocks + 1) / parallel;
dim3 grid(parallel, params.h, params.b); // todo: regroup qheads into seqlen_q and dispatch less blocks
dim3 block(256, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
} else {
constexpr int kBlockM = 64;
constexpr int kBlockN = 128;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // used in kernels
params.total_blocks = params.b * params.h * params.q_blocks;
// dim3 grid(params.q_blocks, params.h, params.b);
dim3 grid(params.cu_count);
dim3 block(512, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_fix_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
}
}
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_dispatch_dsa_prefill_nopage(Flash_fwd_mla_params_dsa &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("MLA_DP_DECODE_NO_MLS") == nullptr/* and params.b >= 16 */) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dim3 dimGrid;
dimBlock.x = min((kBlockM) / (WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
using Kernel_traits = Flash_fwd_kernel_traits<Headdim, HeaddimV, kBlockM, kBlockN, 32/* kBlockK */, WARP_M, 64/* WARP_N */, 2/* STAGES */, false, false, T>;
constexpr int REUSE_KV = 1;
constexpr bool Is_dropout = false;
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
dimGrid.y = 1;//(params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
if(params.seqlen_q == 128){
printf("not support");
// BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// flash::flash_fwd_mla_decode_kernel_gfx938_dsa_decode<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
// <<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
// });
// });
}
else{
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 32 * 1024, stream>>>(params);
});
});
}
} else {
if (params.b * params.h >= params.cu_count / 2) {
constexpr int kBlockM = 32; // vgpr spill 280+ when WARP_M = 32
constexpr int kBlockN = 128;
constexpr int parallel = 2;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // todo: streamkv
params.q_blocks = (params.q_blocks + 1) / parallel;
dim3 grid(parallel, params.h, params.b); // todo: regroup qheads into seqlen_q and dispatch less blocks
dim3 block(256, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
} else {
constexpr int kBlockM = 64;
constexpr int kBlockN = 128;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // used in kernels
params.total_blocks = params.b * params.h * params.q_blocks;
// dim3 grid(params.q_blocks, params.h, params.b);
dim3 grid(params.cu_count);
dim3 block(512, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_fix_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
}
}
}
template<typename T, int Headdim, int HeaddimV>
void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params, hipStream_t stream) {
int gcn_arch = getArch();
if (gcn_arch >= 938 and std::getenv("MLA_DP_DECODE_NO_MLS") == nullptr/* and params.b >= 16 */) {
constexpr int kBlockM = 64;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dim3 dimGrid;
dimBlock.x = min((kBlockM) / (WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
using Kernel_traits = Flash_fwd_kernel_traits<Headdim, HeaddimV, kBlockM, kBlockN, 32/* kBlockK */, WARP_M, 64/* WARP_N */, 2/* STAGES */, false, false, T, T>;
constexpr int REUSE_KV = 1;
constexpr bool Is_dropout = false;
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
dimGrid.y = 1;//(params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
if(params.num_splits == 1){
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
}
else if(params.num_splits != 0){
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
dimGrid.y = params.num_splits;//(params.h == params.h_k) ? params.h: params.h / REUSE_KV;
dimGrid.z = params.b;
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream);
}
else{
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
});
});
}
} else {
if (params.b * params.h >= params.cu_count / 2) {
constexpr int kBlockM = 32; // vgpr spill 280+ when WARP_M = 32
constexpr int kBlockN = 128;
constexpr int parallel = 2;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // todo: streamkv
params.q_blocks = (params.q_blocks + 1) / parallel;
dim3 grid(parallel, params.h, params.b); // todo: regroup qheads into seqlen_q and dispatch less blocks
dim3 block(256, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
} else {
constexpr int kBlockM = 64;
constexpr int kBlockN = 128;
params.q_blocks = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); // used in kernels
params.total_blocks = params.b * params.h * params.q_blocks;
// dim3 grid(params.q_blocks, params.h, params.b);
dim3 grid(params.cu_count);
dim3 block(512, 1, 1);
// BOOL_SWITCH(params.mtp > 1, Is_causal, [&]{
// flash_fwd_mla_fix_kernel<Headdim, HeaddimV, kBlockM, kBlockN, false/*Is_prefix*/, Is_causal, T, float, Flash_fwd_mla_params><<<grid, block, 0, stream>>>(params);
// });
}
}
}
/******************************************************************************************
* Copyright (c) 2025, Baohui.Fang, Yushun.Zhang, Chang.liu, Wenjian.Zhang, Jianbang.Xu
*****************************************************************************************/
#pragma once
#include "config.h"
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
#include "flash_singleton.h"
#include "assert.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_Varlen, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, bool Is_GQA, bool Is_softcap, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, bool Append_KV>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_kernel(Flash_fwd_params params) {
flash::compute_attn_splitkv<Kernel_traits, /*Is_training*/false, Is_dropout, Is_causal, Is_Varlen, Is_local, Is_even_K, Return_softmax, Has_alibi, Is_GQA, Is_softcap, Split, M_MMAC_COUNT, REUSE_KV_TIMES, Append_KV>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Has_alibi, bool Is_GQA, bool Is_softcap, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, bool Append_KV>
__global__ void __launch_bounds__(256,1) flash_fwd_splitkv_int8_kernel(Flash_fwd_params params) {
flash::compute_attn_splitkv_int8<Kernel_traits, /*Is_training*/false, Is_dropout, Is_causal, Is_local, Is_even_K, Return_softmax, Has_alibi, Is_GQA, Is_softcap, Split, M_MMAC_COUNT, REUSE_KV_TIMES, Append_KV>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, bool Is_local, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_tile16x32_kernel(Params params) {
flash::compute_attn_splitkv_tile16x32<Kernel_traits, Is_causal, Is_Varlen, Split, Is_local, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_Varlen, bool Split, int M_MMAC_COUNT, int REUSE_KV_TIMES, int HEADDIM_V_SPLIT, int Partition_Size, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_gfx938_kernel(Params params) {
flash::compute_attn_splitkv_gfx938<Kernel_traits, Is_causal, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>(params);
}
template<typename Kernel_traits, const bool Tail, typename Params>
void run_splitkv_reduce(Params &params, hipStream_t stream) {
// now, only headdim 128/512 support splitkv, since shuffle kernel doesn't support other headdims
if constexpr (Kernel_traits::kHeadDimV == 128 or Kernel_traits::kHeadDimV == 512 or Kernel_traits::kHeadDimV == 64) {
// reduce num_splits x [batch_size, num_head_q, seqlen_q, head_dim] output
if (params.num_splits > 1) {
dim3 block(64);
dim3 grid(params.b * params.h * params.seqlen_q);
constexpr int MAX_NUM_SPLITS = 1024;
if (params.num_splits > MAX_NUM_SPLITS) {
printf("\x1b[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel\033[0m\n", params.num_splits, MAX_NUM_SPLITS);
return;
}
using Element = typename Kernel_traits::Element;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
if (params.num_splits == 2) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 2, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 4) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 4, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 8) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 8, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 16) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 16, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 32) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 32, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 64) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 128) {
block.x = params.num_splits;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 128, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 256) {
block.x = params.num_splits;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 256, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 512) {
block.x = params.num_splits;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 512, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 1024) {
block.x = params.num_splits;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 1024, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
}
/*kernels above can be extremely optimized when unroll is true*/
else if (params.num_splits > 512) {
block.x = 1024;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 1024, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits > 256) {
block.x = 512;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 512, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits > 128) {
block.x = 256;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 256, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits > 64) {
block.x = 128;
flash_fwd_splitkv_reduce_kernel_split128<SplitkvAccumType, Element, 128, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits > 32) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits > 16) {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 32, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else {
flash_fwd_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
}
} // if (params.num_splits > 1)
} // if (kHeadDim == 128)
}
template<typename Kernel_traits, const bool Tail, typename Params>
void run_splitkv_reduce_varlen(Params &params, hipStream_t stream) {
// now, only headdim 128/512 support splitkv, since shuffle kernel doesn't support other headdims
if constexpr (Kernel_traits::kHeadDimV == 128 or Kernel_traits::kHeadDimV == 512 or Kernel_traits::kHeadDimV == 64) {
// reduce num_splits x [batch_size, num_head_q, seqlen_q, head_dim] output
if (params.num_splits > 1) {
dim3 block(64);
dim3 grid(params.h * params.ngroups, params.b); /*total_q 是会变化的, 不能放在这里用于启动 cuda-graph*/
constexpr int MAX_NUM_SPLITS = 64;
if (params.num_splits > MAX_NUM_SPLITS) {
printf("\x1b[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel\033[0m\n", params.num_splits, MAX_NUM_SPLITS);
return;
}
using Element = typename Kernel_traits::Element;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
if (params.num_splits == 2) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 2, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 4) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 4, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 8) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 8, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 16) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 16, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 32) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 32, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else if (params.num_splits == 64) {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 64, true/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
} else {
flash_fwd_splitkv_reduce_varlen_kernel<SplitkvAccumType, Element, 64, false/*unroll*/, Tail, Kernel_traits::kHeadDimV><<<grid, block, 0, stream>>>(params);
}
} // if (params.num_splits > 1)
} // if (kHeadDim == 128)
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, hipStream_t stream) {
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
const size_t smem_misalign = (params.seqlen_q >= 16 or Kernel_traits::kHeadDimV == 512) ? Kernel_traits::kHeadDimV: (Kernel_traits::kHeadDimV + 4)/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
const size_t smem_for_acc = int((params.seqlen_q + 1) / 2) * 2 * WARP_NUM * smem_misalign * sizeof(float);
const size_t smem_for_gemm = std::max(std::max(Kernel_traits::q_smem_size, Kernel_traits::k_smem_size * WARP_NUM), Kernel_traits::v_smem_size * WARP_NUM);
const size_t required_smem_size = std::max(smem_for_acc, smem_for_gemm + smem_for_max);
/*
for gfx936,
2 waves per SIMD is better than 1 waves per SIMD;
3 waves per SIMD will bring performance degradation
for gfx928,
> 1 waves per SIMD will significantly increase the latency of buffer-load, blocked at TA
*/
hipDeviceProp_t props;
auto hip_result = hipGetDeviceProperties(&props, 0);
#ifdef ROCM_5_7
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
#endif
const size_t smem_size = gcn_arch > 928 ? required_smem_size: size_t(64 * 1024);
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | smem_for_gemm: %ld | needed smem_size: %ld | smem_size: %ld\n", smem_for_max, smem_for_acc, smem_for_gemm, required_smem_size, smem_size);
}
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
// dim3 grid(num_m_block, params.h, params.b);
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool has_alibi = (params.alibi_slopes_ptr not_eq nullptr);
BOOL_SWITCH(has_alibi, Has_Alibi, [&] {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 1, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr bool IsEvenMNConst = false;
constexpr bool Is_local = false;
void (*kernel)(Flash_fwd_params);
if (params.mtp > 1) {
kernel = &flash_fwd_splitkv_kernel<Kernel_traits, false, true/*Is_causal*/, false/*Is_Varlen*/, Is_local, IsEvenMNConst && !Is_local && true && Kernel_traits::kHeadDim <= 256, true, false, Has_Alibi, false/*is_gqa*/, false/*Is_softcap*/, Split, M_MMAC_COUNT, 0, false/*Append_KV*/>;
} else {
if constexpr (Kernel_traits::kHeadDim == 128 and Kernel_traits::kHeadDimV == 128) {
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_kernel<Kernel_traits, false, false/*Is_causal*/, Is_Varlen, Is_local, IsEvenMNConst && !Is_local && true && Kernel_traits::kHeadDim <= 256, true, false, Has_Alibi, false/*is_gqa*/, false/*Is_softcap*/, Split, M_MMAC_COUNT, REUSE_KV_TIMES, false/*Append_KV*/>;
});
});
} else { // non-headdim128 cases
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_kernel<Kernel_traits, false, false/*Is_causal*/, false/*Is_Varlen*/, Is_local, IsEvenMNConst && !Is_local && true && Kernel_traits::kHeadDim <= 256, true, false, Has_Alibi, false/*is_gqa*/, false/*Is_softcap*/, Split, M_MMAC_COUNT, REUSE_KV_TIMES, false/*Append_KV*/>;
});
}
}
kernel<<<grid, nthread, smem_size, stream>>>(params);
});
});
});
// reduce PA v2
if (params.q_batch_stride == 0) {
run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
} else {
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params &params, hipStream_t stream) {
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
const size_t smem_for_acc = int((params.seqlen_q + 1) / 2) * 2/*reserved for M_MMAC_COUNT = 2*/ * WARP_NUM * Kernel_traits::kBlockK * sizeof(float);
const size_t q_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockM * Kernel_traits::kBlockK * sizeof(half_t);
const size_t k_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockK * Kernel_traits::kWaveN * sizeof(half_t) * WARP_NUM;
const size_t v_smem_size = k_smem_size;
const size_t smem_for_gemm = std::max(q_smem_size, std::max(k_smem_size, v_smem_size));
const size_t required_smem_size = std::max(smem_for_acc, std::max(smem_for_gemm, smem_for_max));
hipDeviceProp_t props;
auto hip_result = hipGetDeviceProperties(&props, 0);
#ifdef ROCM_5_7
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
#endif
const size_t smem_size = gcn_arch > 928 ? size_t(std::max<size_t>(32 * 1024, required_smem_size)): size_t(64 * 1024);
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | smem_for_gemm: %ld | needed smem_size: %ld | smem_size: %ld\n", smem_for_max, smem_for_acc, smem_for_gemm, required_smem_size, smem_size);
}
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.window_size_left > 0 and params.window_size_right >= 0) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, false, true/*Is_local*/, M_MMAC_COUNT, 0, HEADDIM_V_SPLIT, 0>;
});
} else if (params.mtp == 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
} else {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_tile16x32_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, false/*Is_local*/, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
}
});
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce PA v2
if (params.q_batch_stride == 0) {
run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
} else {
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd_mha(Flash_fwd_params &params, hipStream_t stream) {
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
kernel = &flash_fwd_splitkv_mha_kernel<Kernel_traits, false/*Is_causal*/, false/*Split*/, 1/*HEADDIM_V_SPLIT*/, 0/*Partition_Size*/>;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h, params.b);
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, 0, stream>>>(params);
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd_gfx938(Flash_fwd_params &params, hipStream_t stream) {
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
const size_t smem_for_acc = int((params.seqlen_q + 1) / 2) * 2/*reserved for M_MMAC_COUNT = 2*/ * WARP_NUM * Kernel_traits::kBlockK * sizeof(float);
const size_t q_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockM * Kernel_traits::kBlockK * sizeof(half_t);
const size_t k_smem_size = Kernel_traits::STAGES * Kernel_traits::kBlockK * Kernel_traits::kWaveN * sizeof(half_t) * WARP_NUM;
const size_t v_smem_size = k_smem_size;
const size_t smem_for_gemm = std::max(q_smem_size, std::max(k_smem_size, v_smem_size));
const size_t required_smem_size = std::max(smem_for_acc, std::max(smem_for_gemm, smem_for_max));
const size_t smem_size = size_t(std::max<size_t>(32 * 1024, required_smem_size));
if (std::getenv("FA_DEBUG") != nullptr) {
printf("smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld\n",
smem_for_max, smem_for_acc, q_smem_size, k_smem_size, v_smem_size, smem_for_gemm, required_smem_size, smem_size);
}
// compute block partition along seqlen_q direction
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// decide task dispatch logic
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
// acquire kernel fuction
void (*kernel)(Flash_fwd_params);
constexpr int HEADDIM_V_SPLIT = 1; // no need to split-D
grid.x = num_m_block * HEADDIM_V_SPLIT;
BOOL_SWITCH(params.q_batch_stride == 0, Is_Varlen, [&] {
if (params.mtp == 1) {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_gfx938_kernel<Kernel_traits, false/*Is_causal*/, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
} else {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 16, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr int Partition_Size = 0; // pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_gfx938_kernel<Kernel_traits, true/*Is_causal*/, Is_Varlen, Split, M_MMAC_COUNT, REUSE_KV_TIMES, HEADDIM_V_SPLIT, Partition_Size>;
});
});
});
}
});
// Kernel execution
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
kernel<<<grid, nthread, smem_size, stream>>>(params);
// reduce PA v2
if (params.q_batch_stride == 0) {
run_splitkv_reduce_varlen<Kernel_traits, false/*Tail*/>(params, stream);
} else {
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
}
template<typename T, int Headdim, int HeaddimV>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream) {
// decide whether commonly used headdims
const bool is_commonly_used = params.d % 64 == 0 and params.d_value % 64 == 0/*prefetch 2 32x32 blocks along headdim*/;
// For latest archs, mls can be applied for headdim 128
if ((getArch() >= 938) and std::getenv("PA_NO_MLS") == nullptr and is_commonly_used) {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fwd_gfx938<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
#endif
}
// For MHA-fma, headdim = 128
else if (params.seqlen_q == 1 and !params.seqlenq_ngroups_swapped and Headdim == 128 and HeaddimV == 128 and std::getenv("PA_USE_FMA") != nullptr) {
constexpr int kBlockN = 128;
run_flash_splitkv_fwd_mha<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32/*kBlockM*/, kBlockN, 32/*kBlockK*/, 32, 32, 2/*STAGES*/, false, false, T, float> >(params, stream);
}
else if (params.seqlen_q <= 32/*16x32 tile*/ and not params.splitkv_use_fp32_as_accum and params.alibi_slopes_ptr == nullptr and std::getenv("PA_USE_TILE32X32") == nullptr and is_commonly_used) {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 32 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
run_flash_splitkv_fwd_tile16x32<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
run_flash_splitkv_fwd_tile16x32<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, 2/*STAGES*/, false, false, T, T> >(params, stream);
#endif
} else {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 64 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
// regardless of params.splitkv_use_fp32_as_accum to reduce volume of FA whl
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
if (params.splitkv_use_fp32_as_accum) {
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, float> >(params, stream);
} else {
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, HeaddimV, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T> >(params, stream);
}
#endif
}
}
template<typename Kernel_traits>
void run_int8_flash_splitkv_fwd(Flash_fwd_params &params, hipStream_t stream) {
constexpr int WARP_NUM = Kernel_traits::kBlockN / Kernel_traits::kWaveN;
const size_t smem_for_max = std::max(WARP_NUM * Kernel_traits::kWaveM * sizeof(float), size_t(1024));
const size_t smem_misalign = (params.seqlen_q >= 16 or Kernel_traits::kHeadDimV == 512) ? Kernel_traits::kHeadDimV: (Kernel_traits::kHeadDimV + 4)/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
const size_t smem_for_acc = int((params.seqlen_q + 1) / 2) * 2 * WARP_NUM * smem_misalign * sizeof(float);
const size_t smem_for_gemm = std::max(std::max(Kernel_traits::q_smem_size, Kernel_traits::k_smem_size * WARP_NUM), Kernel_traits::v_smem_size * WARP_NUM);
#if defined(KVCACHE_USE_4STAGES_PINGPANG) // 4 倍 pingpang buffer 已经 32KB 了, 需要跟 max 共享
const size_t required_smem_size = std::max(smem_for_max, std::max(smem_for_acc, smem_for_gemm));
#else
const size_t required_smem_size = std::max(smem_for_acc, smem_for_gemm + smem_for_max);
#endif
/*
for gfx936,
2 waves per SIMD is better than 1 waves per SIMD;
3 waves per SIMD will bring performance degradation
for gfx928,
> 1 waves per SIMD will significantly increase the latency of buffer-load, blocked at TA
*/
hipDeviceProp_t props;
auto hip_result = hipGetDeviceProperties(&props, 0);
#ifdef ROCM_5_7
int gcn_arch = props.gcnArch;
#else
std::string gcn_arch_name(props.gcnArchName);
int gcn_arch = std::stoi(gcn_arch_name.substr(3, 3));
#endif
const size_t smem_size = gcn_arch > 928 ? required_smem_size: size_t(48 * 1024);
// printf("smem_for_max: %ld | smem_for_acc: %ld | smem_for_gemm: %ld | needed smem_size: %ld | smem_size: %ld\n", smem_for_max, smem_for_acc, smem_for_gemm, required_smem_size, smem_size);
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int nthread = Kernel_traits::kBlockN / Kernel_traits::kWaveN * 64;
// dim3 grid(num_m_block, params.h, params.b);
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.h, params.num_splits > 1 ? params.b * params.h : params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
const bool has_alibi = (params.alibi_slopes_ptr not_eq nullptr);
// judge if mtp > 1, attention the causal mask ?
bool use_mtp = bool(params.seqlen_q > 1 and !params.seqlenq_ngroups_swapped);
#ifdef BUILD_ASM
// select most likely used kernels to analyze instruction flow
M_MMAC_COUNT_SWITCH(params.seqlen_q > 1, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
void (*kernel)(Flash_fwd_params);
if (use_mtp) { // MTP > 1
kernel = &flash_fwd_splitkv_int8_kernel<Kernel_traits, false, true/*Is_causal*/, false, false, true, false, false, false/*is_gqa*/, false/*Is_softcap*/, Split/*Split*/, M_MMAC_COUNT, 0, false/*Append_KV*/>;
} else {
REUSEKV_SWITCH(params.seqlen_q, [&] { // MTP = 1, can reuse
kernel = &flash_fwd_splitkv_int8_kernel<Kernel_traits, false, false/*Is_causal*/, false, false, true, false, false, false/*is_gqa*/, false/*Is_softcap*/, Split/*Split*/, M_MMAC_COUNT, REUSE_KV_TIMES, false/*Append_KV*/>;
});
}
kernel<<<grid, nthread, smem_size, stream>>>(params);
});
});
#else
bool is_local_mask = bool((params.window_size_left >= 0 || params.window_size_right >= 0) && !(use_mtp));
if (is_local_mask) {printf("\x1b[31mSliding window attention for Paged-Atention is not supported yet!\033[0m\n");}
BOOL_SWITCH(has_alibi, Has_Alibi, [&] {
M_MMAC_COUNT_SWITCH(params.seqlen_q > 1, M_MMAC_COUNT, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
constexpr bool IsEvenMNConst = false;
constexpr bool Is_local = false;
void (*kernel)(Flash_fwd_params);
if (use_mtp) {
kernel = &flash_fwd_splitkv_int8_kernel<Kernel_traits, false, true/*Is_causal*/, Is_local, IsEvenMNConst && !Is_local && true && Kernel_traits::kHeadDim <= 256, true, false, Has_Alibi, false/*is_gqa*/, false/*Is_softcap*/, Split, M_MMAC_COUNT, 0, false/*Append_KV*/>;
} else {
REUSEKV_SWITCH(params.seqlen_q, [&] {
kernel = &flash_fwd_splitkv_int8_kernel<Kernel_traits, false, false/*Is_causal*/, Is_local, IsEvenMNConst && !Is_local && true && Kernel_traits::kHeadDim <= 256, true, false, Has_Alibi, false/*is_gqa*/, false/*Is_softcap*/, Split, M_MMAC_COUNT, REUSE_KV_TIMES, false/*Append_KV*/>;
});
}
#if defined(FA_KERNEL_TIMER)
hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));
HIP_CHECK(hipEventRecord(start, 0));
#endif
kernel<<<grid, nthread, smem_size, stream>>>(params);
#if defined(FA_KERNEL_TIMER)
HIP_KERNEL_LAUNCH_CHECK();
HIP_CHECK(hipDeviceSynchronize());
HIP_CHECK(hipEventRecord(stop, 0)) ;
HIP_CHECK(hipEventSynchronize(stop));
float ave_time;
HIP_CHECK(hipEventElapsedTime(&ave_time,start, stop));
printf("run_flash_splitkv_fwd: %f\n", ave_time * 1000);
#endif
});
});
});
#endif
run_splitkv_reduce<Kernel_traits, true/*Tail*/>(params, stream);
}
template<typename T, int Headdim, int HeaddimV>
void run_int8_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream) {
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if (params.page_block_size % 64 != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
PA_PAGEBLOCKSIZE_SWITCH(params.page_block_size, [&]{
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
// regardless of params.splitkv_use_fp32_as_accum to reduce volume of FA whl
run_int8_flash_splitkv_fwd<Flash_fwd_kernel_traits<128, 128, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T, int8_t> >(params, stream);
});
#else
constexpr int kBlockN = 128;
if (params.page_block_size % kBlockN != 0) { printf("\x1b[31mPage block size %d is not supported yet!\033[0m\n", params.page_block_size); return; }
constexpr int STAGES = (Headdim == 128) ? 3: (Headdim == 32 ? 1: 2);
if (params.splitkv_use_fp32_as_accum) {
run_int8_flash_splitkv_fwd<Flash_fwd_kernel_traits<128, 128, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, float, int8_t> >(params, stream);
} else {
run_int8_flash_splitkv_fwd<Flash_fwd_kernel_traits<128, 128, 32, kBlockN, 32, 32, 32, STAGES, false, false, T, T, int8_t> >(params, stream);
}
#endif
}
\ No newline at end of file
#pragma once
#include "numeric_types.h"
#include "splitkv.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename accumType, typename reduceType, const int SPLIT_COUNT, const bool UnRoll, const bool Tail, const int kHeadDim, typename Params>
__global__ void flash_fwd_splitkv_reduce_kernel(
Params params) {
static_assert(SPLIT_COUNT <= 64 and (kHeadDim % 128 == 0 or kHeadDim == 64));
int num_splits = UnRoll ? SPLIT_COUNT: params.num_splits;
float* scores_max_ptr = params.scores_max_ptr;
float* scores_sum_ptr = params.scores_sum_ptr;
// 128 threads, each thread won't process more than 4 half data, 2 is appropriate, for 64 threads to processing 128 half
__shared__ float lds[512];
int tx = threadIdx.x;
int block_x = blockIdx.x;
int s_m_split_stride = gridDim.x; // offset from the next split
// recompute the true actual_seqlen_k and num_split
const int bidb = block_x / (params.h * params.seqlen_q);
int actual_seqlen_k;
if (params.is_seqlens_k_cumulative) {
actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
} else {
actual_seqlen_k = params.cu_seqlens_k[bidb];
}
// compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
// each thread dispatch 1 piece of buffer_load; outliners will be assigned minimum value
float s_max_load_ori = exceed_split ? -INFINITY: scores_max_ptr[block_x + tx * s_m_split_stride];
// in a warp, reduce a true max value among 64 threads
float s_max_tmp = s_max_load_ori;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step));
}
// compute rescale coefficient for max (numerator)
float s_max_ratio = __expf(s_max_load_ori - s_max_tmp);
// as above, reduce a true sum value amoing 64 threads in each wave
float s_sum_load_ori = exceed_split ? 0.f: scores_sum_ptr[block_x + tx * s_m_split_stride];
float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step);
}
// max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp;
// finally, do rescale for each split and reduce the sum of them
// each block(1waves) process (num_splits x head_dim) elements in total
// for head_dim 128, each thread process 2 halfs for num_splits times
constexpr int tx_float_count = kHeadDim >> 6;
float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr);
// num_splits may not be 64, and thus need boundary judgement
for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split
float s_scale = lds[i];
bool within_splits = (i < true_num_splits);
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
// read ultimate scale value for current split
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// read ultimate scale value for current split
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
// switch to next split
tx_offset += oaccum_stride;
}
// write results
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result;
}
}
}
template<typename accumType, typename reduceType, const int SPLIT_COUNT, const bool UnRoll, const bool Tail, const int kHeadDim, typename Params>
__global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
static_assert(SPLIT_COUNT % 128 == 0 and (kHeadDim % 128 == 0 or kHeadDim == 64));
constexpr int WAVES_COUNT = SPLIT_COUNT >> 6;
int num_splits = UnRoll ? SPLIT_COUNT: params.num_splits;
float* scores_max_ptr = params.scores_max_ptr;
float* scores_sum_ptr = params.scores_sum_ptr;
// each workgroup need SPLIT_COUNT threads to process SPLIT_COUNT num_splits, each thread process (kHeadDim / 64) floats for final accumulation
constexpr int LDS_ACCUM = (SPLIT_COUNT * (kHeadDim >> 6));
// prepare workspace of 2 floats to reduce max/sum in 2 waves
constexpr int LDS_SIZE = LDS_ACCUM + (SPLIT_COUNT >> 6);
static_assert (LDS_SIZE * sizeof(float) <= 64 * 1024 and "Exceed max lds usage!");
__shared__ float lds[LDS_SIZE];
int tx = threadIdx.x;
int block_x = blockIdx.x;
int s_m_split_stride = gridDim.x; // offset from the next split
// recompute the true actual_seqlen_k and num_split
const int bidb = block_x / (params.h * params.seqlen_q);
int actual_seqlen_k;
if (params.is_seqlens_k_cumulative) {
actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
} else {
actual_seqlen_k = params.cu_seqlens_k[bidb];
}
// compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
// each thread dispatch 1 piece of buffer_load; outliners will be assigned minimum value
float s_max_load_ori = exceed_split ? -INFINITY: scores_max_ptr[block_x + tx * s_m_split_stride];
// in a warp, reduce a true max value among 64 threads
float s_max_tmp = s_max_load_ori;
#pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step));
}
// for multiple waves, store the reduced max value to lds individually, and recompute max across multiple waves
int wave_id = (tx >> 6);
lds[LDS_ACCUM + wave_id] = s_max_tmp;
__syncthreads();
float lds_accum_temp = lds[LDS_ACCUM];
#pragma unroll
for (int s = 1; s < WAVES_COUNT; ++s) {
lds_accum_temp = max(lds_accum_temp, lds[LDS_ACCUM + s]);
}
lds[LDS_ACCUM] = lds_accum_temp;
__syncthreads();
// acquire the reduced max value across multiple waves
s_max_tmp = lds[LDS_ACCUM];
// compute rescale coefficient for max (numerator)
float s_max_ratio = __expf(s_max_load_ori - s_max_tmp);
// as above, reduce a true sum value amoing 64 threads in each wave
float s_sum_load_ori = exceed_split ? 0.f: scores_sum_ptr[block_x + tx * s_m_split_stride];
float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step);
}
// for multiple wave, store the reduced sum value to lds individually, and recompute sum across multiple waves
lds[LDS_ACCUM + wave_id] = s_sum_tmp;
__syncthreads();
lds_accum_temp = lds[LDS_ACCUM];
#pragma unroll
for (int s = 1; s < WAVES_COUNT; ++s) {
lds_accum_temp = lds_accum_temp + lds[LDS_ACCUM + s];
}
lds[LDS_ACCUM] = lds_accum_temp;
__syncthreads();
s_sum_tmp = lds[LDS_ACCUM];
// max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp;
// finally, do rescale for each split and reduce the sum of them
// each block(multiple waves) process (num_splits x head_dim) elements in total
// e.g. for head_dim 128, each thread process 2 halfs for num_splits times
// e.g. for head_dim 512, each thread process 8 halfs for num_splits times
constexpr int tx_float_count = kHeadDim >> 6;
float tx_accum[tx_float_count] = {0.f};
static_assert (tx_float_count * 128 < LDS_SIZE && "for each thread, it's not allowed to processing more than 8 half data");
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
// each wave read data from 0 in 128 halfs, and thus (tx % 64)
// int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count;
int in_batch_offset = block_x - bidb * params.h * params.seqlen_q;
int bidh = in_batch_offset / params.seqlen_q;
int bids = in_batch_offset - bidh * params.seqlen_q;
int real_block_x = params.layout == 0 ? block_x/*bhsd layout*/: bidb * params.seqlen_q * params.h + bids * params.h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
int begin = wave_id << 6;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
// for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ......
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr) + begin * oaccum_stride;
// num_splits may not be multiple of 64, and thus, the multiple waves need boundary judgement
int split_count_this_wave = UnRoll ? 64: min(64, num_splits - begin);
for (int i = 0; i < split_count_this_wave; ++i) {
// read ultimate scale value for current split
float s_scale = lds[begin + i];
bool within_splits = (begin + i) < true_num_splits;
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
// read 2 halfs from current split of this threads
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// read 1 half from current split of this threads
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
// switch to next split
tx_offset += oaccum_stride;
}
// no ds_read op again
__syncthreads();
// for multiple waves, store sum value to lds
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
lds[tx * tx_float_count + t] = tx_accum[t];
if constexpr (kHeadDim % 128 == 0) {
lds[tx * tx_float_count + t + 1] = tx_accum[t + 1];
}
}
__syncthreads();
// the 0th wave does the reduction and write operations
if (wave_id == 0) {
using vec2_fp32 = __attribute__((__vector_size__(2 * sizeof(float)))) float;
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
vec2_fp32 this_wave_f32s = *(vec2_fp32*)(lds + tx * tx_float_count + t);
#pragma unroll
for (int s = 1; s < (SPLIT_COUNT >> 6); ++s) { // 0 wave accumulate data from other waves
vec2_fp32 other_wave_f32s = *(vec2_fp32*)(lds + tx * tx_float_count + t + s * 64 * tx_float_count);
this_wave_f32s[0] += other_wave_f32s[0];
this_wave_f32s[1] += other_wave_f32s[1];
}
*(vec2_fp32*)(lds + tx * tx_float_count + t) = this_wave_f32s;
} else if constexpr (kHeadDim == 64) {
float this_wave_f32s = *(float*)(lds + tx * tx_float_count + t);
#pragma unroll
for (int s = 1; s < (SPLIT_COUNT >> 6); ++s) { // 0 wave accumulate data from other waves
float other_wave_f32s = *(float*)(lds + tx * tx_float_count + t + s * 64 * tx_float_count);
this_wave_f32s += other_wave_f32s;
}
*(float*)(lds + tx * tx_float_count + t) = this_wave_f32s;
}
}
__syncthreads(); // here, __sync may not be necessary
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
tx_accum[t] = lds[tx * tx_float_count + t];
tx_accum[t + 1] = lds[tx * tx_float_count + t + 1];
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
tx_accum[t] = lds[tx * tx_float_count + t];
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
*(reduceType*)(output_ptr + t) = accum_result;
}
}
}
}
template<typename accumType, typename reduceType, const int SPLIT_COUNT, const bool UnRoll, const bool Tail, const int kHeadDim, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel(
Params params) {
static_assert(SPLIT_COUNT <= 64 and (kHeadDim % 128 == 0 or kHeadDim == 64));
constexpr int num_splits = SPLIT_COUNT;
/* bottleneck
1. s_load_dword latency for args
2. s_load_dword for cu_seqlens_q/cu_seqlens_k, cache miss
solution:
1. why kernel args cannot be hit on SQC data cache ? kernel packet different ?
2. llvm-backend rescheduling, overlap args loading and cu_seqlens loading with better granularity
3. asm
*/
// 128 threads, each thread won't process more than 4 half data, 2 is appropriate, for 64 threads to processing 128 half
__shared__ float lds[512];
int tx = threadIdx.x;
int total_q = params.total_q;
int bidh_ngroup = blockIdx.x;
int total_h_ngroup = gridDim.x;
int s_m_split_stride = total_h_ngroup * total_q; // offset from the next split
int bidh = bidh_ngroup / params.ngroups;
int group_id = bidh_ngroup - bidh * params.ngroups;
// recompute the true actual_seqlen_k and num_split
const int bidb = blockIdx.y;
int actual_seqlen_k = params.cu_seqlens_k[bidb];
// varlen q
int sum_s_q = params.cu_seqlens_q[bidb];
int actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;
// compute partition_size when fix num_splits
int partition_size = splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits);
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
float* softmax_lse_ptr = reinterpret_cast<float*>(params.softmax_lse_ptr);
float* softmax_lseaccum_ptr = reinterpret_cast<float*>(params.softmax_lseaccum_ptr);
for (int cur_s_q = 0; cur_s_q < actual_seqlen_q; ++cur_s_q) {
// h * ngroups * (bs)
int block_x = bidh_ngroup * total_q + sum_s_q + cur_s_q;
// load local lse value for each split
float lse_local = softmax_lseaccum_ptr[block_x + min(tx, num_splits - 1) * s_m_split_stride];
__builtin_amdgcn_sched_barrier(0);
// finally, do rescale for each split and reduce the sum of them
// each block(1waves) process (num_splits x head_dim) elements in total
// for head_dim 128, each thread process 2 halfs for num_splits times
constexpr int tx_float_count = kHeadDim >> 6;
float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
// {total_q, ngroups, num_heads, -1} --> {total_q, num_heads, ngroups, -1}
// int real_block_x = (sum_s_q + cur_s_q) * total_h_ngroup + group_id * params.h + bidh;
int real_block_x = (sum_s_q + cur_s_q) * total_h_ngroup + bidh * params.ngroups + group_id;
int tx_offset = real_block_x * kHeadDim + (tx & 63) * tx_float_count;
reduceType* output_ptr = reinterpret_cast<reduceType*>(params.o_ptr) + tx_offset;
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr);
// prefetch all vgprs
constexpr int tx_float_loop = tx_float_count >> 1;
vec2_Element<accumType> load_vec[num_splits][tx_float_loop];
accumType load[num_splits][tx_float_loop];
// num_splits may not be 64, and thus need boundary judgement
#pragma unroll
for (int i = 0; i < num_splits; ++i) {
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
load_vec[i][t >> 1] = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
} else if constexpr (kHeadDim == 64) {
load[i][t >> 1] = *(accumType*)(oaccum_ptr + tx_offset + t);
}
}
// switch to next split
tx_offset += oaccum_stride;
}
__builtin_amdgcn_sched_barrier(0);
// process initialization as -inf
if (exceed_split) {
lse_local = -INFINITY;
}
// reduce max lse
float lse_max_local = lse_local;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step));
}
// reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step);
}
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
// store softmax_lse
if (tx == 0) {
softmax_lse_ptr[block_x] = lse_local_logsum;
}
// store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum);
// num_splits may not be 64, and thus need boundary judgement
#pragma unroll
for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split
float s_scale = lds[i];
bool within_splits = (i < true_num_splits);
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i][t >> 1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
}
// write results
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result;
}
}
}
}
template<typename accumType, typename reduceType, const int SPLIT_COUNT, const bool UnRoll, const bool Tail, const int kHeadDim, typename Params>
__global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
Params params) {
static_assert(SPLIT_COUNT <= 64 and (kHeadDim % 128 == 0 or kHeadDim == 64));
constexpr int WARP_NUM = 4;
constexpr int lds_required_per_wave = (kHeadDim >> 2);
constexpr int num_splits = SPLIT_COUNT; // UnRoll ? SPLIT_COUNT: params.num_splits;
// acquire scalar data by sqc cache
float* softmax_lse_ptr = params.softmax_lse_ptr;
accumType* oaccum_ptr = reinterpret_cast<accumType*>(params.oaccum_ptr);
void* o_ptr = params.o_ptr;
int32_t* cu_seqlens_k = params.cu_seqlens_k;
int params_num_splits = params.num_splits;
int params_partition_size = params.partition_size;
int h = params.h;
int seqlen_q = params.seqlen_q;
int layout = params.layout;
// acquire task
int tx = threadIdx.x & 63;
int wave_id = threadIdx.x >> 6;
int block_x = blockIdx.x;
int s_m_split_stride = gridDim.x; // offset from the next split
const int bidb = block_x / (h * seqlen_q);
// prefetch buffer to overlap with s_load_dword*
float lse_local = softmax_lse_ptr[block_x + min(tx, num_splits - 1) * s_m_split_stride];
// share rescale across threads
__shared__ float lds_space[512 + 1024];
float* lds = lds_space + wave_id * lds_required_per_wave;
// recompute the true actual_seqlen_k and num_split
// int actual_seqlen_k = params.topk_length[block_x / 64];//cu_seqlens_k[bidb];
// int row = block_x / 64; // 你当前 h_q=64 路径
// int main_len = params.topk_length ? params.topk_length[row] : params.topk;
// int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
// int actual_seqlen_k = ceil_div(main_len, 64) * 64 + ceil_div(extra_len, 64) * 64;
int row = block_x / 64;
int main_len = params.topk_length ? params.topk_length[row] : params.topk;
int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
int total_blocks = ceil_div(main_len, 64) + ceil_div(extra_len, 64);
int blocks_per_split = ceil_div(params.partition_size, 64);
int true_num_splits = ceil_div(total_blocks, blocks_per_split);
// for flashmla, 512 elements are engaged to 4 blocks
// within each block, num_splits / WARM_NUM load transactions are engaged to each wave
constexpr int tx_float_count = (kHeadDim >> 2) >> 6;
float tx_accum[tx_float_count] = {0.f};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int oaccum_stride = s_m_split_stride * kHeadDim;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int in_batch_offset = block_x - bidb * h * seqlen_q;
int bidh = in_batch_offset / seqlen_q;
int bids = in_batch_offset - bidh * seqlen_q;
int real_block_x = layout == 0 ? block_x/*bhsd layout*/: bidb * seqlen_q * h + bids * h + bidh/*bshd layout*/;
int tx_offset = real_block_x * kHeadDim + tx * tx_float_count + blockIdx.y * (kHeadDim >> 2) + min(wave_id, num_splits - 1) * oaccum_stride;
reduceType* output_ptr = reinterpret_cast<reduceType*>(o_ptr) + tx_offset;
// fetch all data into vgprs
constexpr int SPLITS_PER_WAVE = std::max<int32_t>(1, num_splits >> 2);
vec2_Element<accumType> load[SPLITS_PER_WAVE][tx_float_count >> 1];
#pragma unroll
for (int i = 0; i < num_splits; i += WARP_NUM) {
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
load[i >> 2][t >> 1] = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
}
// switch to next split
tx_offset += WARP_NUM * oaccum_stride;
}
// compute partition_size when fix num_splits
// int partition_size = params_partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params_partition_size;
// const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
// process initialization as -inf
if (exceed_split) {
lse_local = -INFINITY;
}
// reduce max lse
float lse_max_local = lse_local;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step));
}
// reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step);
}
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
float attn_sink_o_scale = 1.0f;
if (params.attn_sink != nullptr) {
// 当前 reduce kernel 的 block_x 是按 b,h,s 展开的,所以 bidh 就是 head id。
float rAttn_sink = params.attn_sink[block_x % 64];
if (rAttn_sink == INFINITY) {
attn_sink_o_scale = 0.0f;
} else if (lse_local_logsum != -INFINITY && lse_local_logsum != INFINITY) {
float lse_exp = __expf(lse_local_logsum);
float sink_exp = __expf(rAttn_sink);
attn_sink_o_scale = lse_exp / (lse_exp + sink_exp);
}
}
// store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum) * attn_sink_o_scale;
// num_splits may not be 64, and thus need boundary judgement
#pragma unroll
for (int i = 0; i < num_splits; i += WARP_NUM) {
// read ultimate scale value for current split
bool within_splits = ((i + wave_id) < true_num_splits);
float s_scale = num_splits >= WARP_NUM ? lds[i + wave_id]: (within_splits ? lds[i + wave_id]: 0.f);
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
}
}
// reduce across 4 waves
float *reduce_lds = lds_space + 512;
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
reduce_lds[wave_id * lds_required_per_wave + tx + t * 64] = tx_accum[t];
reduce_lds[wave_id * lds_required_per_wave + tx + (t + 1) * 64] = tx_accum[t + 1];
}
__syncthreads();
if (wave_id == 0) {
// write results
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
// get data from other wave
#pragma unroll
for (int neighbor = 1; neighbor < WARP_NUM; ++neighbor) {
tx_accum[t] += reduce_lds[neighbor * lds_required_per_wave + tx + t * 64];
tx_accum[t + 1] += reduce_lds[neighbor * lds_required_per_wave + tx + (t + 1) * 64];
}
// cvt
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]);
#endif
// storation
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
}
}
}
\ No newline at end of file
...@@ -2,11 +2,17 @@ ...@@ -2,11 +2,17 @@
#include <stdexcept> #include <stdexcept>
#include "dsa_mls/fwd.h"
#include "phase1.h" #include "phase1.h"
namespace gfx93 { namespace gfx93 {
void run_fwd_kernel(const SparseAttnFwdParams& params) { void run_fwd_kernel(const SparseAttnFwdParams& params) {
if (gfx93::fwd::dsa_mls::should_run(params)) {
gfx93::fwd::dsa_mls::run(params);
return;
}
const bool have_topk_length = params.topk_length != nullptr; const bool have_topk_length = params.topk_length != nullptr;
// Dispatch based on d_qk dimension and presence of topk_length // Dispatch based on d_qk dimension and presence of topk_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment