Commit a1eef562 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#pragma once
#include "int8_kvcache_qk_gemm_prefetch_v_3stage.h"
#define USE_DS_OVERLAP_MMAC
/*
* int8 KV-cache QK GEMM with K streaming and V prefetch.
*
* For each kBlockK-wide slice of the head dimension this routine
*   1. issues K_LOAD_REQUESTS calls to inline_buffer_load_dword_lds (gK -> k_lds),
*   2. precomputes the four addresses used to read K back out of k_lds,
*   3. waits for the buffer loads, reads K into the local k_reg array,
*   4. runs __builtin_hcu_mmac_i32_16x16x32_i8 in two halves (min_tile_n = 0, then 1),
*      accumulating into s_reg.
* When STAGES == 2 the loop starts at k_loop = 1 and each iteration computes on the
* slice k_loop - 1 while the loads for slice k_loop are in flight; a trailing block
* after the loop then consumes the last slice. Finally the block synchronizes and,
* when STAGES > 1, calls int8_kvcache_prefetch_v_to_lds for V.
*
* Parameters (as used by the visible code):
*   gQ, q_lds             - not referenced in this function body.
*   gK                    - forwarded to inline_buffer_load_dword_lds as the K source
*                           (presumably a buffer resource descriptor - TODO confirm).
*   gV, v_lds             - forwarded to int8_kvcache_prefetch_v_to_lds.
*   k_lds                 - destination of the buffer loads and source of the ds reads
*                           (presumably LDS / shared memory - TODO confirm).
*   q_reg                 - read-only left operand of the MMAC calls.
*   s_reg                 - in/out accumulator; it is only read-modified-written here,
*                           so the caller must have initialized it.
*   WARP_ID               - wave index; selects this wave's K rows and k_lds region.
*   kcache_seqlen_stride  - stride multiplied by the K row index when forming load offsets.
*   vcache_seqlen_stride  - forwarded to the V prefetch.
*   max_seq_k_offset      - upper bound used to clamp the K row index (row <= max_seq_k_offset - 1).
*                           NOTE(review): with the default of -1 the clamp yields -2, i.e. a
*                           negative row index; presumably callers always pass a positive
*                           value - confirm.
*
* Original header note, kept for reference:
*   gQ: Transposed 32x16 matrix
*   gK: Non-transposed 32x16 matrix
*   s_ptr: Non-transposed 32x32 matrix
* NOTE(review): there is no parameter named s_ptr in this signature; the note above
* looks stale (it may refer to s_reg) - confirm.
*/
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int kBlockK_v, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename Element_k, typename ElementAccum>
__forceinline__ __device__ void int8_kvcache_qk_gemm_prefetch_v(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element_k* q_lds,
Element_k* k_lds,
Element_k* v_lds,
vec4_int8 q_reg[(kHeadDim/kBlockK)*((WARP_M*kBlockK)/(32*kBlockK))*2][4],
vec4_int32 s_reg[(WARP_M/32)*(WARP_N/32)][4],
int WARP_ID,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset = -1) {
// static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
// union_vec4_f16x2<Element> k_reg[STAGES*((WARP_N*kBlockK)/(32*32))*2];
// Per-thread staging registers for K, one group per pipeline stage.
// NOTE(review): the first dimension is sized with a trailing *4, while every index
// expression below uses a stride of *2 per stage - confirm the intended size.
vec4_int8 k_reg[STAGES*((WARP_N*kBlockK)/(32*32))*4][4];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;  // which group of 16 lanes (0-3)
int laneid_and_15 = lane_id & 15;  // position inside that group (0-15)
int qk_lane_m_idx = (laneid_shfl_4 & 1)*2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
// WARP_N/WARP_N - 1 evaluates to 0, so this mask always yields 0.
int k_warp_n_id = (WARP_ID & (WARP_N/WARP_N - 1));
// Per-lane component of the k_lds read offset; it is added to the stage offset when
// the read addresses are precomputed below.
int k_ds_read_offset = k_warp_n_id*(WARP_N/32)*(32*17) + (lane_id & 1)*16 + (laneid_and_15>>1)*65 + (laneid_shfl_4 & 1)*8 + (lane_id/32);
// Number of inline_buffer_load_dword_lds calls issued per K slice.
constexpr int k_lds_load_num = (WARP_N*kBlockK) / (4*32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
int stage_id = 0;
// With two stages the loop starts at 1 and computes on slice k_loop - 1 (see k_loop_idx).
// Presumably slice 0 was already requested by the caller - TODO confirm.
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
//to_be_modified
for(int k_loop = K_LOOP_START; k_loop<(kHeadDim/kBlockK); k_loop++) {
// ---- Step 1: issue the buffer loads for slice k_loop into stage `stage_id` of k_lds.
if constexpr (true) {
// int k_lane_seq_idx = (laneid_shfl_4);
// neighbour sequence is in the same thread --->(seq0, seq1) in thread0, (seq2, seq3) in thread1...
// int k_lane_seq_idx = ((laneid_shfl_4) & 1)*2 + ((laneid_shfl_4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
// int k_lane_head_dim_idx = laneid_and_15;
// Offset shared by the whole wave: head-dim slice start plus this wave's first K row.
int k_block_buffer_load_global_offset = k_loop * kBlockK + WARP_ID * WARP_N * kcache_seqlen_stride;
// Load-side stage offset is expressed in units of 32*34 and halved when lds_offset is
// formed below; the read side further down uses 32*17 (= 32*34 / 2) directly.
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK/32)*(32*34) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK/32)*(32*34);
for(int load = 0; load < K_LOAD_REQUESTS; ++load) {
// Rotate the request order by WARP_ID so different waves start at different requests.
int __load = (load + WARP_ID) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = (__load & (WARP_N/4 - 1));
// int k_warp_buffer_load_k_id = (warp_loop / (WARP_N/4));
int k_warp_buffer_load_lds_offset = k_lds_stage_offset/* + (k_warp_buffer_load_k_id * WARP_N * 34)*/ + ((k_warp_buffer_load_n_id >> 3)*(32*34) + (k_warp_buffer_load_n_id & 7)*(4*32)) ; ;
// int k_warp_buffer_load_global_offset = (k_warp_buffer_load_k_id * 32);
// Offsets are halved before being handed to the load helper (unit presumably differs
// from the element unit used above - TODO confirm against inline_buffer_load_dword_lds).
int gvOffset_s = (k_block_buffer_load_global_offset/* + k_warp_buffer_load_global_offset*/) / 2;
// Per-lane offset: K row (4 rows per request, clamped to max_seq_k_offset - 1) plus
// the lane's position along the head dimension.
int gvOffset_v = ((min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kcache_seqlen_stride) / 2 + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / 2;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, gvOffset_s, gvOffset_v);
}
}
//to_be_modified
// Precompute these offsets ahead of the wait.
// ---- Step 2: switch to the stage that is about to be consumed and compute its read addresses.
if constexpr (STAGES == 2) stage_id ^=1;
int precompute_k_lds_offset[2*2];
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK/32)*(32*17) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK/32)*(32*17);
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
vec4_int8 *k_lds_v4int8 = (vec4_int8 *)(k_lds);
// NOTE(review): the array has only 2*2 entries indexed by i*2 + j, so the n_idx and
// head_dim_idx loops overwrite the same entries and only the last (n_idx, head_dim_idx)
// pair survives. This is only consistent when WARP_N/32 == 1 and kBlockK/32 == 1.
// NOTE(review): the value is a size_t address plus a byte offset stored into an int
// (narrowing); presumably k_lds addresses fit in 32 bits - confirm.
for(int i=0; i<2; i++) {
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
for(int head_dim_idx=0; head_dim_idx<(kBlockK/32); head_dim_idx++) {
for(int j=0; j<2; j++) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v4int8) + (k_lds_stage_offset + head_dim_idx*(WARP_N*17) + n_idx*(32*17) + j*4 + i*32 + k_ds_read_offset) * 4;
}
}
}
}
// sched_barrier(0) fences keep the compiler from moving instructions across the wait.
__builtin_amdgcn_sched_barrier(0);
// ---- Step 3: wait for the buffer loads feeding the stage that is consumed next.
// The template argument is presumably the number of loads allowed to remain in
// flight - TODO confirm against buffer_load_lds_dwordx1_wait_nosync.
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>(); // With the current scheme each wave processes its own data, so no cross-wave sync is needed until the max is computed
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
// ---- Step 4: read K from k_lds into k_reg (four reads: i in {0,1} x j in {0,1}).
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
// int k_warp_n_id = (WARP_ID & (WARP_N/WARP_N - 1));
// int k_lds_stage_offset = stage_id * (WARP_N/32) * (kBlockK/32)*(32*17);
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int j=0; j<2; j++) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
// inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i].u64[j], 2);
// inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i][j], 2);
// NOTE(review): only sub-elements [0] and [1] of each k_reg row are targeted here,
// while the MMAC calls below read sub-elements 2*min_tile_k and 2*min_tile_k + 1,
// i.e. [0]..[3]. Unless inline_ds_read_b32_wait fills more than the element it is
// given, [2] and [3] are never written - confirm.
inline_ds_read_b32_wait(k_lds_v4int8, lds_offset, k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i][j]);
}
}
}
}
}
// USE_DS_OVERLAP_MMAC is defined at the top of this header, so the lgkmcnt(2) branch is
// the active one: per the AMD ISA, execution continues once at most two LDS operations
// are still outstanding, letting the first MMAC half (k_reg rows with i == 0) overlap
// with the remaining two reads.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
// Raise the wave priority for the MMAC section; it is lowered again after the second half.
asm volatile("s_setprio 1");
//to_be_modified
// ---- Step 5a: MMAC, first half (min_tile_n = 0 selects the k_reg rows read with i == 0).
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
// With two stages the data consumed now belongs to the previous head-dim slice.
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
// vec4_Element<Element>{q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1]},
// vec4_Element<Element>{k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][0],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][1],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][0],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][1]},
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
// Eight int8 values from q_reg and eight from k_reg are packed into vec8_int8 operands;
// the current s_reg value is passed as the third argument and overwritten with the result.
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m] = __builtin_hcu_mmac_i32_16x16x32_i8(vec8_int8{q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][2],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][3],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][2],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][3]},
vec8_int8{k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][0],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][1],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][2],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][3],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][0],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][1],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][2],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][3]},
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m]);
}
}
}
}
}
}
// The second half consumes the k_reg rows read with i == 1, so all LDS reads must be done.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
// ---- Step 5b: MMAC, second half (min_tile_n = 1).
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
// vec4_Element<Element>{q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
// q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1]},
// vec4_Element<Element>{k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][0],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][1],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][0],
// k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][1]},
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m] = __builtin_hcu_mmac_i32_16x16x32_i8(vec8_int8{q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][2],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][3],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][2],
q_reg[(k_loop_idx)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][3]},
vec8_int8{k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][0],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][1],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][2],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][3],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][0],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][1],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][2],
k_reg[stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][3]},
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m]);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// Keep the actual stage_id of the last batch of data from phase 1.
// After the loop stage_id names the stage that was consumed last, so the other stage
// (stage_id ^ 1) is the one whose loads were issued last and are still unconsumed.
int last_stage_id = stage_id ^ 1;
//to_be_modified
// Wait for the last batch of phase-1 data to return, then compute on it.
// This trailing block mirrors steps 2-5 of the loop for the final head-dim slice
// (q_reg slice index (kHeadDim/kBlockK)-1) and only exists for the two-stage pipeline.
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// Precompute the LDS load indices ahead of the wait.
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
vec4_int8 *k_lds_v4int8 = (vec4_int8 *)(k_lds);
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK/32)*(32*17) + last_stage_id * WARP_NUM * (WARP_N/32) * (kBlockK/32)*(32*17);
// NOTE(review): same caveats as in the loop - the 4-entry array is overwritten across
// head_dim_idx / n_idx, and a size_t address is narrowed to int.
for(int head_dim_idx=0; head_dim_idx<(kBlockK/32); head_dim_idx++) {
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
for(int i=0; i<2; i++) {
for(int j=0; j<2; j++) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v4int8) + (k_lds_stage_offset + head_dim_idx*(WARP_N*17) + n_idx*(32*17) + j*4 + i*32 + k_ds_read_offset) * 4;
}
}
}
}
}
// Wait for the last batch of data to return.
__builtin_amdgcn_sched_barrier(0);
// NOTE(review): when kBlockN >= 2*WARP_N the wait is issued with K_LOAD_REQUESTS instead
// of 0; presumably other loads are expected to be in flight in that configuration -
// confirm against the caller and buffer_load_lds_dwordx1_wait_nosync.
if constexpr (kBlockN >= (WARP_N * 2)) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
// Read the last stage from k_lds into the k_reg rows belonging to last_stage_id.
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
// int k_warp_n_id = (WARP_ID & (WARP_N/WARP_N - 1));
// int k_lds_stage_offset = stage_id * (WARP_N/32) * (kBlockK/32)*(32*17);
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
vec4_int8 *k_lds_v4int8 = (vec4_int8 *)(k_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
for(int i=0; i<2; i++) {
for(int j=0; j<2; j++) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
// inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i][j], 2);
inline_ds_read_b32_wait(k_lds_v4int8, lds_offset, k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i][j]);
}
}
}
}
}
// Same overlap scheme as in the loop: proceed once at most two LDS reads are outstanding.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
//to_be_modified
// MMAC, first half (min_tile_n = 0) for the final head-dim slice.
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
// vec4_Element<Element>{q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1]},
// vec4_Element<Element>{k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][0],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][1],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][0],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][1]},
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m] = __builtin_hcu_mmac_i32_16x16x32_i8(vec8_int8{q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][2],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][3],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][2],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][3]},
vec8_int8{k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][0],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][1],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][2],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][3],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][0],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][1],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][2],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][3]},
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m]);
}
}
}
}
}
}
// All remaining LDS reads must be complete before the second half touches the i == 1 rows.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
// MMAC, second half (min_tile_n = 1) for the final head-dim slice.
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (kBlockK/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
// vec4_Element<Element>{q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
// q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1]},
// vec4_Element<Element>{k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][0],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k][1],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][0],
// k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n].f16x2[2*min_tile_k + 1][1]},
// s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m] = __builtin_hcu_mmac_i32_16x16x32_i8(vec8_int8{q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][0],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][1],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][2],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k][3],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][0],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][1],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][2],
q_reg[((kHeadDim/kBlockK)-1)*((WARP_M*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][2*min_tile_k + 1][3]},
vec8_int8{k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][0],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][1],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][2],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k][3],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][0],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][1],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][2],
k_reg[last_stage_id*((WARP_N*kBlockK)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + min_tile_n][2*min_tile_k + 1][3]},
s_reg[(n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m]);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// Before the QK GEMM waits for the data needed by its last computation, the V load instructions can already be issued.
// NOTE(review): as written, the V prefetch is issued after the trailing block and the
// barrier above, not before the final wait that the comment describes - confirm intent.
if constexpr (STAGES > 1) {
int8_kvcache_prefetch_v_to_lds<kHeadDimV, kBlockM, kBlockK_v, kBlockN, WARP_M, kBlockK_v, 32/*WARP_K*/, 0, WARP_NUM, Element_k, STAGES>(gV, v_lds, WARP_ID, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "int8_kvcache_qk_gemm_utils.h"
#include "int8_kvcache_pv_gemm_utils.h"
// mmac(): one int8 MMAC accumulation step, active only when bytes_per_Element == 1.
// The expansion site must have these names in scope: bytes_per_Element, q_reg, q_idx,
// k_reg, k_idx, s_reg, min_tile_k, min_tile_m and min_tile_n.
// It packs the two vec4 halves (2*min_tile_k and 2*min_tile_k + 1) of q_reg[q_idx] and of
// k_reg[k_idx].f8x4 into vec8_int8 operands, then accumulates in place into
// s_reg[0][min_tile_n*2 + min_tile_m].int32.
#define mmac() \
if (bytes_per_Element == 1) {\
    const auto& mmac_q_lo_ = q_reg[q_idx][2*min_tile_k];\
    const auto& mmac_q_hi_ = q_reg[q_idx][2*min_tile_k + 1];\
    const auto& mmac_k_lo_ = k_reg[k_idx].f8x4[2*min_tile_k];\
    const auto& mmac_k_hi_ = k_reg[k_idx].f8x4[2*min_tile_k + 1];\
    auto& mmac_acc_ = s_reg[0][min_tile_n*2 + min_tile_m].int32;\
    mmac_acc_ = __builtin_hcu_mmac_i32_16x16x32_i8(\
        vec8_int8{mmac_q_lo_[0], mmac_q_lo_[1], mmac_q_lo_[2], mmac_q_lo_[3],\
                  mmac_q_hi_[0], mmac_q_hi_[1], mmac_q_hi_[2], mmac_q_hi_[3]},\
        vec8_int8{mmac_k_lo_[0], mmac_k_lo_[1], mmac_k_lo_[2], mmac_k_lo_[3],\
                  mmac_k_hi_[0], mmac_k_hi_[1], mmac_k_hi_[2], mmac_k_hi_[3]},\
        mmac_acc_);\
}
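/*
* 3-stage ping-pong buffer. Its implementation differs considerably from the earlier 2-stage version, so the two were not unified.
* A 4-stage variant is not considered for now: max_lds cannot share storage with V, and with 4 waves and 4x the LDS usage the footprint already reaches 32KB, leaving no room for max_lds. A 4-stage ping-pong buffer is therefore not applicable at the moment.
*/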
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int kBlockK_v, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename Element_q, typename ElementAccum>
__forceinline__ __device__ void int8_kvcache_qk_gemm_prefetch_v_3stage(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element_q* q_lds,
Element_q* k_lds,
Element_q* v_lds,
vec4_int8 q_reg[(kHeadDim/kBlockK)*((WARP_M*kBlockK)/(32*kBlockK))*2][4],
union_vec4_int32 s_reg[(WARP_M/32)*(WARP_N/32)][4],
/*vec4_Accum<ElementAccum> s_reg[(WARP_M/32)*(WARP_N/32)][4],*/
int WARP_ID,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset = -1) {
//计算常量,每个元素的字节数
const int bytes_per_Element = 1;
//计算常量,每个dword中包含的元素个数
int Element_per_dword = 4/bytes_per_Element;
// vec4_int8 k_reg[STAGES*((WARP_N*kBlockK*bytes_per_Element/2)/(32*32))*4][4];
union_vec4_f16x2<int8_t> k_reg[STAGES*((WARP_N*kBlockK*bytes_per_Element/2)/(32*32))*2];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1)*2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int k_warp_n_id = (WARP_ID & (WARP_N/WARP_N - 1));
int k_ds_read_offset = k_warp_n_id*(WARP_N/32)*(32*17) + (lane_id & 1)*16 + (laneid_and_15>>1)*65 + (laneid_shfl_4 & 1)*8 + (lane_id/32);
int k_lds_load_num = (WARP_N*kBlockK) / (4*64);
int K_LOAD_REQUESTS = k_lds_load_num;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_q, 2>;
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg);
} else {
inline_vgpr4_init_zero_1x4x4(s_reg);
}
__builtin_amdgcn_sched_barrier(0);
#else
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
#endif
int k_loop = 0;
int stage_id = 0;
// for(int k_loop = K_LOOP_START; k_loop<(kHeadDim/kBlockK); k_loop++)
{
stage_id = 0;
//当使用int8,kBlockK由32变为64,所以下面的计算会发生变化,需要重新计算,to_be_modified
// 在 wait 之前提前计算这部分偏移量
int precompute_k_lds_offset[2*2];
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK*bytes_per_Element/2/32)*(32*17) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK*bytes_per_Element/2/32)*(32*17);
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
vec4_int8 *k_lds_v4int8 = (vec4_int8 *)(k_lds);
for(int i=0; i<2; i++) {
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
for(int j=0; j<2; j++) {
precompute_k_lds_offset[i * 2 + j] = (k_lds_stage_offset + head_dim_idx*(WARP_N*17) + n_idx*(32*17) + j*4 + i*32 + k_ds_read_offset);
}
}
}
}
vmcnt_wait(K_LOAD_REQUESTS);
// vmcnt_wait(0);
//to_be_modified
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
// int k_warp_n_id = (WARP_ID & (WARP_N/WARP_N - 1));
// int k_lds_stage_offset = stage_id * (WARP_N/32) * (kBlockK*bytes_per_Element/2/32)*(32*17);
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/64); head_dim_idx++) {
#pragma unroll
for(int j=0; j<2; j++) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
int reg_offset = stage_id * 2 + i;
// inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id*((WARP_N*kBlockK*bytes_per_Element/2)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i].u64[j], 2);
k_reg[reg_offset].f8x4[j*2][0] = k_lds_v4int8[lds_offset][0];
k_reg[reg_offset].f8x4[j*2][1] = k_lds_v4int8[lds_offset][1];
k_reg[reg_offset].f8x4[j*2][2] = k_lds_v4int8[lds_offset][2];
k_reg[reg_offset].f8x4[j*2][3] = k_lds_v4int8[lds_offset][3];
k_reg[reg_offset].f8x4[j*2+1][0] = k_lds_v4int8[lds_offset+2][0];
k_reg[reg_offset].f8x4[j*2+1][1] = k_lds_v4int8[lds_offset+2][1];
k_reg[reg_offset].f8x4[j*2+1][2] = k_lds_v4int8[lds_offset+2][2];
k_reg[reg_offset].f8x4[j*2+1][3] = k_lds_v4int8[lds_offset+2][3];
}
}
}
}
// #ifdef USE_DS_OVERLAP_MMAC
// asm volatile("s_waitcnt lgkmcnt(2)");
// #else
// asm volatile("s_waitcnt lgkmcnt(0)");
// #endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
//block_k增大为64,循环体可能需要修改,to_be_modified
{
int min_tile_n = 0;
for(int head_dim_idx=0; head_dim_idx< (kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
int k_loop_idx = k_loop;
int q_idx = k_loop_idx*2 + (head_dim_idx*(WARP_M/32))*2 + min_tile_m;
int k_idx = stage_id*2 + head_dim_idx*(WARP_N/32)*2 + min_tile_n;
mmac();
}
}
}
}
// #ifdef USE_DS_OVERLAP_MMAC
// asm volatile("s_waitcnt lgkmcnt(0)");
// #endif
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
{
int min_tile_n = 1;
for(int head_dim_idx=0; head_dim_idx< (kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
int k_loop_idx = k_loop;
int q_idx = k_loop_idx*2 + (head_dim_idx*(WARP_M/32))*2 + min_tile_m;
int k_idx = stage_id*2 + head_dim_idx*(WARP_N/32)*2 + min_tile_n;
mmac();
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
k_loop = 1;
//to_be_modified
{
{
stage_id = 1;
// int k_lane_seq_idx = (laneid_shfl_4);
// neighbour sequence is in the same thread --->(seq0, seq1) in thread0, (seq2, seq3) in thread1...
// int k_lane_seq_idx = ((laneid_shfl_4) & 1)*2 + ((laneid_shfl_4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
// int k_lane_head_dim_idx = laneid_and_15;
int k_block_buffer_load_global_offset = (k_loop) * kBlockK + WARP_ID * WARP_N * kcache_seqlen_stride;
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK/32)*(32*34) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK/32)*(32*34);
for(int load = 0; load < K_LOAD_REQUESTS; ++load) {
// int __load = (load + WARP_ID) % K_LOAD_REQUESTS;
int padding = (load & 7); // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = (load & (WARP_N/4 - 1));
int gvOffset_s = (k_block_buffer_load_global_offset/* + k_warp_buffer_load_global_offset*/) / Element_per_dword;
int gvOffset_v = ((min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kcache_seqlen_stride) / Element_per_dword + qk_lane_head_dim_idx;
int lds_offset = (k_lds_stage_offset) / Element_per_dword + padding + ((k_warp_buffer_load_n_id >> 3)*(32*17) + (k_warp_buffer_load_n_id & 7)*64);
BUFFER_LOAD_FUNC(k_lds, gK, lds_offset, gvOffset_s, gvOffset_v);
}
}
stage_id = 1;
//to_be_modified
// 在 wait 之前提前计算这部分偏移量
int precompute_k_lds_offset[2*2];
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK*bytes_per_Element/2/32)*(32*17) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK*bytes_per_Element/2/32)*(32*17);
vec4_int8 *k_lds_v4int8 = (vec4_int8 *)(k_lds);
for(int i=0; i<2; i++) {
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
for(int j=0; j<2; j++) {
precompute_k_lds_offset[i * 2 + j] = (k_lds_stage_offset + head_dim_idx*(WARP_N*17) + n_idx*(32*17) + j*4 + i*32 + k_ds_read_offset);
}
}
}
}
vmcnt_wait(0);
//to_be_modified
if constexpr (true) {
// lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int j=0; j<2; j++) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
int reg_offset = stage_id * 2 + i;
// inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id*((WARP_N*kBlockK*bytes_per_Element/2)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx)*2 + i].u64[j], 2);
k_reg[reg_offset].f8x4[j*2][0] = k_lds_v4int8[lds_offset][0];
k_reg[reg_offset].f8x4[j*2][1] = k_lds_v4int8[lds_offset][1];
k_reg[reg_offset].f8x4[j*2][2] = k_lds_v4int8[lds_offset][2];
k_reg[reg_offset].f8x4[j*2][3] = k_lds_v4int8[lds_offset][3];
k_reg[reg_offset].f8x4[j*2+1][0] = k_lds_v4int8[lds_offset+2][0];
k_reg[reg_offset].f8x4[j*2+1][1] = k_lds_v4int8[lds_offset+2][1];
k_reg[reg_offset].f8x4[j*2+1][2] = k_lds_v4int8[lds_offset+2][2];
k_reg[reg_offset].f8x4[j*2+1][3] = k_lds_v4int8[lds_offset+2][3];
}
}
}
}
// #ifdef USE_DS_OVERLAP_MMAC
// asm volatile("s_waitcnt lgkmcnt(2)");
// #else
// asm volatile("s_waitcnt lgkmcnt(0)");
// #endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
//to_be_modified
{
int min_tile_n = 0;
for(int head_dim_idx=0; head_dim_idx< (kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
int k_loop_idx = k_loop;
int q_idx = k_loop_idx*2 + (head_dim_idx*(WARP_M/32))*2 + min_tile_m;
int k_idx = stage_id*2 + head_dim_idx*(WARP_N/32)*2 + min_tile_n;
mmac();
}
}
}
}
// #ifdef USE_DS_OVERLAP_MMAC
// asm volatile("s_waitcnt lgkmcnt(0)");
// #endif
__builtin_amdgcn_sched_barrier(0);
//to_be_modified
{
int min_tile_n = 1;
for(int head_dim_idx=0; head_dim_idx< (kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
int k_loop_idx = k_loop;
int q_idx = k_loop_idx*2 + (head_dim_idx*(WARP_M/32))*2 + min_tile_m;
int k_idx = stage_id*2 + head_dim_idx*(WARP_N/32)*2 + min_tile_n;
mmac();
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
// __syncthreads();
// 可以先把需要的 V load 指令发下去;
int8_kvcache_prefetch_v_to_lds<kHeadDimV, kBlockM, kBlockK_v, kBlockN, WARP_M, kBlockK_v, 32/*WARP_K*/, 0, WARP_NUM, Element_q, STAGES>(gV, v_lds, WARP_ID, vcache_seqlen_stride, max_seq_k_offset);
} // qk_gemm
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
// Stages the Q tile through LDS and into registers, one kBlockK-wide slice of the head
// dimension per step.
//
// For every slice the function issues builtin_buffer_load_dword_lds requests (gQ -> q_lds, per the
// "global->lds" comments below) and then reads the slice back with inline_ds_read_b32_wait into q_reg.
//   * STAGES > 1 : loads run one slice ahead of the reads. Slice 0 is requested in a prologue, the
//                  main loop requests slice k_loop while reading slice k_loop - 1 from the other LDS
//                  stage, and an epilogue reads the final slice.
//   * STAGES == 1: each slice is requested and read inside the same loop iteration.
//
// Parameters:
//   gQ               - source operand handed to the buffer-load helper
//   q_lds            - LDS staging area for Q
//   q_reg            - destination registers, indexed [slice/tile slot][j]
//   WARP_ID          - index of the calling warp; warp w issues load requests w, w + WARP_NUM, ...
//   max_seq_q_offset - exclusive upper bound on the Q row index; rows at or beyond it are not requested.
//                      NOTE(review): with the default of -1 the guard below never passes, so nothing
//                      is loaded — callers presumably always pass a real bound; confirm.
//
// The template parameters Element and the local SEQUENCE_READ/Is_GQA select between the GQA layout
// (M_MMAC_COUNT > 1, two sequence reads) and the MHA layout (one read).
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, typename Element_q, int STAGES, int REUSE_KV_TIMES, int M_MMAC_COUNT>
__forceinline__ __device__ void int8_kvcache_prefetch_q_to_vgpr(
vec4_uint gQ,
Element_q* q_lds,
vec4_int8 q_reg[(kHeadDim/kBlockK)*((WARP_M*kBlockK)/(32*kBlockK))*2][4],
int WARP_ID,
int max_seq_q_offset = -1) {
// Constant: number of bytes per element.
const int bytes_per_Element = 1;
// Constant: number of elements contained in one dword.
const int Element_per_dword = 4/bytes_per_Element;
// const int WARP_NUM = (kBlockM) / (WARP_M);
// const int q_lds_load_num = (kBlockM * kBlockK) / (4 * 32); // 32 * 32 / 4 * 32 = 8
// const int Q_LOAD_REQUESTS = q_lds_load_num / WARP_NUM; // 8 / 4 = 2
// static_assert(REUSE_KV_TIMES <= 16 and WARP_NUM == 4);
constexpr bool Is_GQA = M_MMAC_COUNT > 1;
// NOTE(review): `/` binds tighter than `<<`, so the GQA arm parses as
// `((REUSE_KV_TIMES + 1) >> 1) << (2 / WARP_NUM)`. That equals `(... << 2) / WARP_NUM` for
// WARP_NUM of 1, 2 or 4 but not for other values — confirm the intent before changing WARP_NUM.
constexpr int Q_LOAD_REQUESTS = Is_GQA ? ((REUSE_KV_TIMES + 1) >> 1) << 2 / WARP_NUM: 1/*MHA only need the first token*/;
// Number of i-iterations in the LDS read loops: 2 for GQA, 1 otherwise.
constexpr int SEQUENCE_READ = Is_GQA ? 2: 1;
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = ((lane_id >> 4) & 1) * 2 + ((lane_id >> 4) >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int q_lane_head_dim_idx = lane_id & 15;
int stage_id = 0;
//to_be_modified
// Prologue (multi-stage only): request slice 0 into LDS stage 0.
if constexpr (STAGES > 1) {
int k_loop = 0;
// global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM/32) * (kBlockK/32)*(32*34);
// Requests are distributed round-robin over warps: this warp takes WARP_ID, WARP_ID + WARP_NUM, ...
for(int load = 0, warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*4; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = (warp_loop & (kBlockM/4 - 1));
int q_warp_buffer_load_lds_offset = q_lds_stage_offset/* + (q_warp_buffer_load_k_id * kBlockM * 34)*/ + ((q_warp_buffer_load_m_id >> 3)*(32*68) + (q_warp_buffer_load_m_id & 7)*(4*64));
int gvOffset_s = (q_block_buffer_load_global_offset/* + q_warp_buffer_load_global_offset*/) / Element_per_dword;
// Row index handled by this lane; reused below as the per-lane dword offset.
int gvOffset_v = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
// Skip rows at or beyond the caller-supplied bound.
if (gvOffset_v < max_seq_q_offset) {
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / Element_per_dword;
// int lds_offset = (q_lds_stage_offset) / Element_per_dword + padding + (q_warp_buffer_load_m_id & 7)*64;
// Row stride is kHeadDim elements; convert to dwords and add the lane's position within the row.
gvOffset_v = (gvOffset_v * kHeadDim) / Element_per_dword + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, gvOffset_s, gvOffset_v);
}
}
}
// After the prologue the next request targets the other stage.
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
//to_be_modified
// Main loop: request slice k_loop, then read back the slice that is already resident in LDS.
for(int k_loop = K_LOOP_START; k_loop<(kHeadDim/kBlockK); k_loop++) {
// global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM/32) * (kBlockK/32)*(32*34);
for(int load = 0, warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*4; // padding size in shared memory per buffer load, to avoid bank conflict
int q_warp_buffer_load_m_id = (warp_loop & (kBlockM/4 - 1));
int q_warp_buffer_load_lds_offset = q_lds_stage_offset/* + (q_warp_buffer_load_k_id * kBlockM * 34)*/ + ((q_warp_buffer_load_m_id >> 3)*(32*68) + (q_warp_buffer_load_m_id & 7)*(4*64));
int gvOffset_s = (q_block_buffer_load_global_offset/* + q_warp_buffer_load_global_offset*/) / Element_per_dword;
int gvOffset_v = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (gvOffset_v < max_seq_q_offset) {
// int lds_offset = (q_lds_stage_offset) / Element_per_dword + padding + (q_warp_buffer_load_m_id & 7)*64;
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / Element_per_dword;
gvOffset_v = (gvOffset_v * kHeadDim) / Element_per_dword + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, gvOffset_s, gvOffset_v);
}
}
//to_be_modified
// s_waitcnt with an all-zero immediate followed by a block barrier: every request issued so far
// (by any warp) has to land in q_lds before the reads below.
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// Multi-stage: flip back to the stage holding the previous slice, which is the one read now.
if constexpr (STAGES > 1) stage_id ^= 1;
// Stage offset recomputed for indexing q_lds through the vec4_int8 view below.
q_lds_stage_offset = stage_id * (kBlockM/32) * (kBlockK*bytes_per_Element/2/32)*(32*17);
// vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
vec4_int8 *q_lds_v4int8 = (vec4_int8 *)(q_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int i=0; i<SEQUENCE_READ; i++) {
#pragma unroll
for(int j=0; j<4; j++) {
// NOTE(review): this offset does not depend on m_idx, so for WARP_M > 32 every m_idx reads the
// same LDS locations — presumably only WARP_M == 32 is exercised; confirm.
int lds_offset = q_lds_stage_offset + head_dim_idx*kBlockM*17 + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// Slice being read: the previous one when loads run ahead, otherwise the current one.
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
// for mha, only 0/32/16/48 need read, and thus if (lane_id % 16 == 0), but (land_id & 15 == 0) will lead to errors
// NOTE(review): the per-slice stride uses (32*64) while the declared extent of q_reg uses
// (32*kBlockK); the two agree only for kBlockK == 64 — confirm that is the only configuration.
inline_ds_read_b32_wait(q_lds_v4int8, lds_offset, q_reg[k_loop_idx*((WARP_M*kBlockK)/(32*64))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j]);
}
}
}
}
// Barrier before the next iteration issues loads into the stage that was just read.
__syncthreads();
}
//to_be_modified
// Epilogue (multi-stage only): read the last slice, whose loads were issued in the final loop iteration.
// NOTE(review): when kHeadDim/kBlockK == 1 the loop above never runs, so these reads follow the
// prologue loads with only the s_waitcnt below and no __syncthreads() — verify that case is unused.
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM/32) * (kBlockK*bytes_per_Element/2/32)*(32*17);
// vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
vec4_int8 *q_lds_v4int8 = (vec4_int8 *)(q_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockK*bytes_per_Element/2/32); head_dim_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int i=0; i<SEQUENCE_READ; i++) {
#pragma unroll
for(int j=0; j<4; j++) {
int lds_offset = q_lds_stage_offset + head_dim_idx*kBlockM*17 + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// for mha, only 0/32/16/48 need read, and thus if (lane_id % 16 == 0), but (land_id & 15 == 0) will lead to errors
// Destination slot is that of the last slice, (kHeadDim/kBlockK) - 1.
inline_ds_read_b32_wait(q_lds_v4int8, lds_offset, q_reg[((kHeadDim/kBlockK) - 1)*((WARP_M*kBlockK)/(32*64))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j]);
}
}
}
}
}
}
// Issues the buffer loads that bring the first kBlockK-wide slice (k_loop == 0) of this warp's
// K rows into LDS stage 0. The function only acts when STAGES > 1; otherwise its body is empty.
// It does not wait for the loads to complete — no wait or barrier appears in this function.
//
// Parameters:
//   k_ptr                - source operand handed to the buffer-load helper
//   k_lds                - LDS destination for K
//   WARP_ID              - index of the calling warp; selects both the global row block
//                          (WARP_ID * WARP_N rows) and the per-warp LDS region
//   kcache_seqlen_stride - distance, in elements, between consecutive K rows
//   max_seq_k_offset     - row indices are clamped to max_seq_k_offset - 1.
//                          NOTE(review): with the default of -1 the clamp yields -2, i.e. a negative
//                          row offset — callers presumably always pass a real bound; confirm.
//
// kHeadDim, kBlockM, WARP_M and Element are not referenced in the body.
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_N, typename Element, typename Element_k, int STAGES, int WARP_NUM>
__forceinline__ __device__ void int8_kvcache_prefetch_k_to_lds(
vec4_uint k_ptr,
Element_k* k_lds,
int WARP_ID,
int kcache_seqlen_stride,
int max_seq_k_offset = -1) {
// Constant: number of bytes per element.
const int bytes_per_Element = 1;
// Constant: number of elements contained in one dword.
int Element_per_dword = 4/bytes_per_Element;
// const int WARP_NUM = (kBlockM)/(WARP_M);
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1)*2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
// Number of load requests this warp issues for one WARP_N x kBlockK slice.
constexpr int k_lds_load_num = (WARP_N*kBlockK) / (4*64);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_k, 2>;
int stage_id = 0;
int k_loop = 0;
//to_be_modified
if constexpr (STAGES > 1) {
// Global offset of this warp's first row for slice k_loop, in elements.
int k_block_buffer_load_global_offset = k_loop * kBlockK + WARP_ID * WARP_N * kcache_seqlen_stride;
// LDS region of this warp within the selected stage.
int k_lds_stage_offset = WARP_ID * (WARP_N/32)* (kBlockK/32)*(32*34) + stage_id * WARP_NUM * (WARP_N/32) * (kBlockK/32)*(32*34);
for(int load = 0; load < K_LOAD_REQUESTS; ++load) {
// int __load = (load + WARP_ID) % K_LOAD_REQUESTS;
int padding = (load & 7); // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_n_id = (load & (WARP_N/4 - 1));
int gvOffset_s = (k_block_buffer_load_global_offset/* + k_warp_buffer_load_global_offset*/) / Element_per_dword;
// Per-lane dword offset: the row (clamped to max_seq_k_offset - 1) times the row stride, plus the
// lane's position within the row.
int gvOffset_v = ((min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kcache_seqlen_stride) / Element_per_dword + qk_lane_head_dim_idx;
int lds_offset = (k_lds_stage_offset) / Element_per_dword + padding + ((k_warp_buffer_load_n_id >> 3)*(32*17) + (k_warp_buffer_load_n_id & 7)*64);
BUFFER_LOAD_FUNC(k_lds, k_ptr, lds_offset, gvOffset_s, gvOffset_v);
}
}
}
#pragma once
#include "philox.cuh"
#include "fwd/utils.h"
using namespace flash;
// Sets to -INFINITY every score whose key column lies at or beyond max_seqlen_k.
//
// The column owned by a lane for a given (32-wide N tile, sub-tile, vector slot) is
//   col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + slot * 8.
// All M tiles and all M_MMAC_COUNT row sub-tiles of such a column are overwritten.
//
//   tensor          - per-lane score registers, indexed [tile_m + tile_n * (WARP_M/32)][sub_n * 2 + sub_m]
//   max_seqlen_k    - exclusive upper bound on valid column indices
//   col_idx_offset_ - column index of the first column of this warp's tile
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_apply_mask(DataType tensor[(WARP_M/32)*(WARP_N/32)][4], const int max_seqlen_k,
                                               const int col_idx_offset_ = 0) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;  // lane id, 0-63
    const int lane_col_base = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
        #pragma unroll
        for (int sub_n = 0; sub_n < 2; ++sub_n) {
            #pragma unroll
            for (int slot = 0; slot < 4; ++slot) {
                const int col = lane_col_base + tile_n * 32 + sub_n + slot * 8;
                if (col >= max_seqlen_k) {
                    #pragma unroll
                    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
                        DataType *tile = tensor[tile_m + tile_n * kTilesM];
                        #pragma unroll
                        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                            // A packed (pk) instruction could be used here as well; to be added later.
                            tile[sub_n * 2 + sub_m].f32[slot] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Applies dropout to a warp's score tile using Philox-generated random bytes.
//
// For each (M tile, N tile) pair one flash::philox call yields four 32-bit words, consumed as
// sixteen bytes: byte (sub_n * 2 + sub_m) * 4 + slot belongs to the element with the same indices.
// An element whose byte is >= p_dropout_in_8bits_value is set to zero; the others are left
// untouched (no rescaling happens here). Columns at or beyond max_seqlen_k are skipped.
//
//   rowcol - Philox counter, taken by value: .u32.x advances by BLOCK_ROW_STRIDE per M tile and
//            .u32.y by one per N tile (y is not reset between M tiles).
//   dropout_debug_count - only referenced by the disabled debug block below.
template <typename DataType, int WARP_M, int WARP_N, int BLOCK_ROW_STRIDE, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_apply_dropout(DataType tensor[(WARP_M/32)*(WARP_N/32)][4], const int max_seqlen_k, const int col_idx_offset_,
                                                  unsigned long long seed, unsigned long long offset, uint32_t p_dropout_in_8bits_value,
                                                  union_vec2_uint rowcol, uint32_t* dropout_debug_count) {
    // static_assert(WARP_M == 32 and "For Dropout, only WARP_M=32 is supported yet!");
    constexpr int kTilesM = WARP_M / 32;
    constexpr uint32_t kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;  // lane id, 0-63
    const int lane_col_base = col_idx_offset_ + (lane >> 4) * 2;
    // Four 32-bit words viewed as sixteen random bytes.
    union_vec4_uint rand_bytes;
    // When WARP_M > 32, note that block_row_idx is computed from BLOCK_M / 32 rather than BLOCK_M / WARP_M.
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m, rowcol.u32.x += BLOCK_ROW_STRIDE) {
        #pragma unroll
        for (uint32_t tile_n = 0; tile_n < kTilesN; ++tile_n, ++rowcol.u32.y) {
            // One Philox draw covers the sixteen elements this lane owns in the current tile.
            rand_bytes.u32 = flash::philox(seed, rowcol.u64, offset);
            int cnt = 0;
            #pragma unroll
            for (uint32_t sub_n = 0; sub_n < 2; ++sub_n) {
                #pragma unroll
                for (uint32_t slot = 0; slot < 4; ++slot) {
                    const int col = lane_col_base + tile_n * 32 + sub_n + slot * 8;
                    if (col < max_seqlen_k) {
                        #pragma unroll
                        for (uint32_t sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                            const uint32_t byte_pos = (sub_n * 2 + sub_m) * 4 + slot;
                            // Widen the byte to 32 bits, since DCU has no 8/16-bit compare instructions.
                            const uint32_t rand_val = static_cast<uint32_t>(rand_bytes.u8[byte_pos]);
                            if (rand_val >= p_dropout_in_8bits_value) {
                                tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[slot] = 0x0;
                                ++cnt;
                            }
                        }
                    }
                }
            }
#if 0
            atomicAdd(dropout_debug_count, cnt);
            if (threadIdx.x == 0) atomicAdd(dropout_debug_count + 1, 1);
#endif
        }
    }
}
// Applies the causal mask to a warp's score tile.
//
// A lane owns rows  row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m  and columns
// col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + slot * 8. An element is set to
// -INFINITY when its column is greater than the last column the row may attend to, which is
//   min(max_seqlen_k - 1, row + max_seqlen_k - max_seqlen_q).
//
// The clamp uses max_seqlen_k - 1 (an inclusive bound) so that column index max_seqlen_k, which is
// already past the end of the key sequence, is masked too. This matches int8_kvcache_apply_mask,
// which treats every col_idx >= max_seqlen_k as invalid; clamping to max_seqlen_k while comparing
// with `>` would leave exactly that one column unmasked whenever the clamp is active.
//
//   tensor          - per-lane score registers, indexed [tile_m + tile_n * (WARP_M/32)][sub_n * 2 + sub_m]
//   col_idx_offset_ - column index of the first column of this warp's tile
//   max_seqlen_k    - number of valid key columns
//   row_idx_offset_ - row index of the first row of this warp's tile
//   max_seqlen_q    - number of query rows
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_apply_mask_causal(DataType tensor[(WARP_M/32)*(WARP_N/32)][4], const int col_idx_offset_,
                                                      const int max_seqlen_k, const int row_idx_offset_,
                                                      const int max_seqlen_q) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_row_base = row_idx_offset_ + (lane & 15) * 2;
    const int lane_col_base = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = lane_row_base + tile_m * 32 + sub_m;
            // Inclusive index of the last visible column for this row.
            // Note: when max_seqlen_k == max_seqlen_q, VGPR usage could be reduced further.
            const int last_visible_col = std::min(max_seqlen_k - 1, row + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int slot = 0; slot < 4; ++slot) {
                        const int col = lane_col_base + tile_n * 32 + sub_n + slot * 8;
                        if (col > last_visible_col) {
                            tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[slot] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Applies the local (sliding-window) mask to a warp's score tile.
//
// A lane owns rows  row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m  and columns
// col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + slot * 8. An element is set to
// -INFINITY when
//   * its column is greater than  min(max_seqlen_k - 1, row + max_seqlen_k - max_seqlen_q + window_size_right), or
//   * HasWSLeft is set and its column is less than
//     max(0, row + 1 + max_seqlen_k - max_seqlen_q - window_size_left) - 1.
//
// The right-hand clamp uses max_seqlen_k - 1 (an inclusive bound). Clamping to max_seqlen_k while
// comparing with `>` left column index max_seqlen_k — one past the end of the key sequence —
// unmasked whenever the window reached the end of the sequence, which disagrees with
// int8_kvcache_apply_mask (it treats every col_idx >= max_seqlen_k as invalid).
//
//   tensor            - per-lane score registers, indexed [tile_m + tile_n * (WARP_M/32)][sub_n * 2 + sub_m]
//   col_idx_offset_   - column index of the first column of this warp's tile
//   max_seqlen_k      - number of valid key columns
//   row_idx_offset_   - row index of the first row of this warp's tile
//   max_seqlen_q      - number of query rows
//   window_size_left  - window extent to the left of the diagonal (used only when HasWSLeft)
//   window_size_right - window extent to the right of the diagonal
template <bool HasWSLeft=true, typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_apply_mask_local(DataType tensor[(WARP_M/32)*(WARP_N/32)][4], const int col_idx_offset_,
                                                     const int max_seqlen_k, const int row_idx_offset_,
                                                     const int max_seqlen_q,
                                                     const int window_size_left, const int window_size_right) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_row_base = row_idx_offset_ + (lane & 15) * 2;
    const int lane_col_base = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = lane_row_base + tile_m * 32 + sub_m;
            // Left bound is kept in its original "+1, compare against limit - 1" form so that a
            // non-positive window start clamps to 0 and never masks anything.
            const int col_idx_limit_left = std::max(0, row + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
            // Inclusive index of the last visible column for this row.
            const int last_visible_col = std::min(max_seqlen_k - 1, row + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int slot = 0; slot < 4; ++slot) {
                        const int col = lane_col_base + tile_n * 32 + sub_n + slot * 8;
                        const bool right_of_window = col > last_visible_col;
                        const bool left_of_window = HasWSLeft && col < (col_idx_limit_left - 1);
                        if (right_of_window || left_of_window) {
                            tensor[tile_m + tile_n * kTilesM][sub_n * 2 + sub_m].f32[slot] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Adds a linear positional bias to every score of a warp's tile:
//   score += gAlibi * (col - row)
// where row = row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m and
//       col = col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + slot * 8.
// max_seqlen_k and max_seqlen_q are accepted but not used by the body.
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_apply_alibi(DataType tensor[(WARP_M/32)*(WARP_N/32)][4], const int col_idx_offset_,
                                                const int max_seqlen_k, const int row_idx_offset_,
                                                const int max_seqlen_q, float gAlibi) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int lane_row_base = row_idx_offset_ + (lane & 15) * 2;
    const int lane_col_base = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = lane_row_base + tile_m * 32 + sub_m;
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                DataType *tile = tensor[tile_m + tile_n * kTilesM];
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int slot = 0; slot < 4; ++slot) {
                        const int col = lane_col_base + tile_n * 32 + sub_n + slot * 8;
                        tile[sub_n * 2 + sub_m].f32[slot] += gAlibi * (col - row);
                    }
                }
            }
        }
    }
}
// Per-lane reduction of a score tile with `op` (the max pass of the softmax).
//
// For every M tile and row sub-tile, folds the lane's 2 x 4 values of every N tile into one value.
//   zero_init == true : the accumulator starts at -INFINITY and the result is stored in `summary`.
//   zero_init == false: the accumulator starts from the value in `summary` and the result is stored
//                       in `summary_cur`; `summary` is only read.
// OpType is not used by the body (per the original comment: 0 is the sum operator, 1 the max operator).
//
// NOTE(review): results are stored at index m_idx * 2, whereas int8_kvcache_quad_allreduce_ walks
// indices 0 .. WARP_M/32 - 1. The two agree only for WARP_M == 32 — confirm larger WARP_M is unused.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void int8_kvcache_thread_reduce_max(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    // Where the running result lives: `summary` for a fresh reduction, `summary_cur` otherwise.
    DataType1 *dest;
    if constexpr (zero_init) {
        dest = summary;
    } else {
        dest = summary_cur;
    }
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            if constexpr (zero_init) {
                dest[tile_m * 2].f32[sub_m] = -INFINITY;
            } else {
                dest[tile_m * 2].f32[sub_m] = summary[tile_m * 2].f32[sub_m];
            }
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                const DataType0 *tile = tensor[tile_m + tile_n * kTilesM];
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    // mmac min tile is 16*16, a warp is 64 threads
                    for (int slot = 0; slot < 4; ++slot) {
                        dest[tile_m * 2].f32[sub_m] = op(dest[tile_m * 2].f32[sub_m], tile[sub_n * 2 + sub_m].f32[slot]);
                    }
                }
            }
        }
    }
}
// Per-lane reduction of a score tile into row sums.
//
//   zero_init == true : the accumulator starts at zero and the result is stored in `summary`.
//   zero_init == false: the accumulator starts from the value in `summary` and the result is stored
//                       in `summary_cur`; `summary` is only read.
//
// On gfx936 / gfx938 / gfx946 both row sub-tiles are accumulated together with the packed
// __builtin_hcu_pk_add_f32 (this path always adds sub-tiles 0 and 1 and does not call `op`).
// On other targets each of the M_MMAC_COUNT sub-tiles is folded with `op`.
// OpType is not used by the body.
//
// NOTE(review): results are stored at index m_idx * 2, whereas int8_kvcache_quad_allreduce_ walks
// indices 0 .. WARP_M/32 - 1. The two agree only for WARP_M == 32 — confirm larger WARP_M is unused.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    // Where the running result lives: `summary` for a fresh reduction, `summary_cur` otherwise.
    DataType1 *dest;
    if constexpr (zero_init) {
        dest = summary;
    } else {
        dest = summary_cur;
    }
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
        // gfx936 and newer provide v_pk_add_f32.
        if constexpr (zero_init) {
            // Could instead start from the first pair, but that reportedly upsets the compiler;
            // the gain is small, so plain zero-initialisation is kept.
            dest[tile_m * 2].u64 = 0x0;
        } else {
            dest[tile_m * 2].u64 = summary[tile_m * 2].u64;
        }
        #pragma unroll
        for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
            const DataType0 *tile = tensor[tile_m + tile_n * kTilesM];
            #pragma unroll
            for (int sub_n = 0; sub_n < 2; ++sub_n) {
                // mmac min tile is 16*16, a warp is 64 threads
                for (int slot = 0; slot < 4; ++slot) {
                    __float2 additem_pair = {tile[sub_n * 2].f32[slot], tile[sub_n * 2 + 1].f32[slot]};
                    dest[tile_m * 2].u64 = __builtin_hcu_pk_add_f32(
                        dest[tile_m * 2].u64,
                        additem_pair
                    );
                }
            }
        }
#else
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            if constexpr (zero_init) {
                dest[tile_m * 2].f32[sub_m] = 0;
            } else {
                dest[tile_m * 2].f32[sub_m] = summary[tile_m * 2].f32[sub_m];
            }
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                const DataType0 *tile = tensor[tile_m + tile_n * kTilesM];
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    // mmac min tile is 16*16, a warp is 64 threads
                    for (int slot = 0; slot < 4; ++slot) {
                        dest[tile_m * 2].f32[sub_m] = op(dest[tile_m * 2].f32[sub_m], tile[sub_n * 2 + sub_m].f32[slot]);
                    }
                }
            }
        }
#endif
    }
}
// Combines the per-lane partial results of each M tile with Allreduce<64>::run and stores the
// outcome in dst. dst and src may point to the same array (the callers in this file pass the
// same pointer for both). Presumably the reduction spans the 64 lanes of a wave — confirm
// against the Allreduce definition.
template<typename Operator, typename DataType, int WARP_M>
__device__ inline void int8_kvcache_quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
    constexpr int kTilesM = WARP_M / 32;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        dst[tile_m] = Allreduce<64>::run(src[tile_m], op);
    }
}
// Full reduction of a score tile: a per-lane pass followed by an all-reduce of the partial results.
//
//   OpType == 0 : sum  (int8_kvcache_thread_reduce_sum)
//   OpType == 1 : max  (int8_kvcache_thread_reduce_max)
//   any other OpType does nothing.
//
//   zero_init == true : the result ends up in `summary`; `summary_cur` is not touched.
//   zero_init == false: `summary` provides the starting value and the result ends up in `summary_cur`.
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void int8_kvcache_reduce_(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    if constexpr (OpType == 0) {  // sum
        int8_kvcache_thread_reduce_sum<zero_init, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
        // The per-lane pass wrote into `summary` or `summary_cur` depending on zero_init.
        DataType1 *result = zero_init ? summary : summary_cur;
        int8_kvcache_quad_allreduce_<Operator, DataType1, WARP_M>(result, result, op);
    } else if constexpr (OpType == 1) {  // max
        int8_kvcache_thread_reduce_max<zero_init, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
        DataType1 *result = zero_init ? summary : summary_cur;
        int8_kvcache_quad_allreduce_<Operator, DataType1, WARP_M>(result, result, op);
    }
}
// Row-wise maximum of a score tile, reduced with MaxOp<float>.
//   zero_init == true : `max` receives the current max score; max_cur is not used (nullptr).
//   zero_init == false: `max` holds the previous max score and is only read; the updated max is
//                       written to max_cur, which must not be null.
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void int8_kvcache_reduce_max(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *max , DataType1 *max_cur=nullptr) {
    MaxOp<float> reducer;
    // max_cur is forwarded unconditionally; the zero_init == true path downstream never reads it.
    int8_kvcache_reduce_<zero_init, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, max, reducer, max_cur);
}
// Row-wise sum of a score tile, reduced with SumOp<float>.
//   zero_init == true : the sum starts from zero and is stored in `sum`; sum_cur is not used.
//   zero_init == false: `sum` provides the starting value and is only read; the result is written
//                       to sum_cur, which must not be null.
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void int8_kvcache_reduce_sum(DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){
    SumOp<float> reducer;
    // sum_cur is forwarded unconditionally; the zero_init == true path downstream never reads it.
    int8_kvcache_reduce_<zero_init, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, sum, reducer, sum_cur);
}
// Replaces every score x of the tile with exp2(x * scale - max_scaled), where max_scaled is the
// row maximum multiplied by `scale` (Scale_max == true) or by log2(e) (Scale_max == false).
//
// Instead of exp(x - max) this evaluates exp2(x * log2(e) - max * log2(e)), which lets the
// multiply and subtract fuse into a single FMA. A row whose maximum is -INFINITY (everything
// masked) uses max_scaled = 0 so that (-inf) - (-inf) never produces NaN.
//
//   tensor - per-lane score registers, indexed [tile_m + tile_n * (WARP_M/32)][sub_n * 2 + sub_m]
//   max    - row maxima, read at index tile_m * 2, component sub_m
//   scale  - multiplier applied to the scores
template <bool Scale_max=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], const DataType1 *max, const float scale) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    #pragma unroll
    for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            // The float(...) cast around M_LOG2E keeps the multiplication out of fp64.
            const float max_scaled = (max[tile_m * 2].f32[sub_m] == -INFINITY)
                ? 0.f
                : (max[tile_m * 2].f32[sub_m] * (Scale_max ? scale : float(M_LOG2E)));
            __float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
            __float2 scale_pair = {scale, scale};
            #pragma unroll
            for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
                DataType0 *tile = tensor[tile_m + tile_n * kTilesM];
                // The min tile is 32*32 and the mmac size is 16x16x16, hence two sub-tiles per direction.
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                    // Packed FMA over the two 64-bit halves, then a scalar exp2 per element.
                    for (int half = 0; half < 2; ++half) {
                        tile[sub_n * 2 + sub_m].u64[half] = __builtin_hcu_pk_fma_f32(
                            tile[sub_n * 2 + sub_m].u64[half],
                            scale_pair,
                            neg_max_scaled_pair
                        );
                    }
                    for (int slot = 0; slot < 4; ++slot) {
                        tile[sub_n * 2 + sub_m].f32[slot] = __llvm_exp2_f32(tile[sub_n * 2 + sub_m].f32[slot]);
                    }
#else
                    for (int slot = 0; slot < 4; ++slot) {
                        tile[sub_n * 2 + sub_m].f32[slot] = __llvm_exp2_f32(tile[sub_n * 2 + sub_m].f32[slot] * scale - max_scaled);
                    }
#endif
                }
            }
        }
    }
}
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1, typename DataType2, int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int M_MMAC_COUNT>
inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/32)*(WARP_M/32)][4], DataType1 *scores_max, DataType1 *scores_sum,
DataType0 acc_o[(K/kBlockK) * ((WARP_M/32)*(kBlockK/32))][4], DataType2* max_lds, int WARP_ID, float softmax_scale_log2) {
if constexpr (Is_first) {
int8_kvcache_reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max);
int8_kvcache_scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, softmax_scale_log2);
int8_kvcache_reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum);
} else {
// float scores_max_cur[WARP_M/16]; //calculate max of each row
DataType1 scores_max_cur[(WARP_M/32)];
int8_kvcache_reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
int lane_id = threadIdx.x & 63;
if constexpr (WARP_NUM == 4) {
// 求 4 个 wave 最大值里的最大值
if(lane_id < 16) {
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
max_lds[WARP_ID*WARP_M + mi*32 + lane_id*2] = scores_max_cur[mi].f32[0];
} else {
*(__float2*)(max_lds + WARP_ID*WARP_M + mi*32 + lane_id*2) = scores_max_cur[mi].u64;
}
}
__syncthreads();
// 0 号 wave reduce 其他 wave 的最大值
if (WARP_ID == 0){
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 0) {
DataType2 tmp = max_lds[mi*32 + lane_id*2];
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++ ) {
tmp = max(tmp, max_lds[warp_loop*WARP_M + mi*32 + lane_id*2]);
}
max_lds[mi*32 + lane_id*2] = tmp;
} else {
__float2 cur_wave_max = *(__float2*)(max_lds + mi*32 + lane_id*2);
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++ ) {
__float2 other_warp_max = *(__float2*)(max_lds + warp_loop*WARP_M + mi*32 + lane_id*2);
cur_wave_max[0] = max(cur_wave_max[0], other_warp_max[0]);
cur_wave_max[1] = max(cur_wave_max[1], other_warp_max[1]);
}
*(__float2*)(max_lds + mi*32 + lane_id*2) = cur_wave_max;
}
}
}
}
__syncthreads();
// 4 个 wave 从 lds 读取最终 reduce 的最大值
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
scores_max_cur[mi].f32[0] = max_lds[mi*32 + (lane_id&15)*2];
} else {
scores_max_cur[mi].u64 = *(__float2*)(max_lds + mi*32 + (lane_id&15)*2);
}
}
// 等 4 个 wave 取完最大值, 因为后续还要写 max_lds
__syncthreads();
}
for (int mi = 0; mi < (WARP_M/32); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m=0; min_tile_m<M_MMAC_COUNT; min_tile_m++) {
float scores_max_cur_reg = !Check_inf
? scores_max_cur[mi*2].f32[min_tile_m]
: (scores_max_cur[mi*2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi*2].f32[min_tile_m]);
float scores_scale = __llvm_exp2_f32((scores_max[mi*2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi*2].f32[min_tile_m] *= scores_scale;
__float2 scores_scale_pair = {scores_scale, scores_scale};
// __float2 scores_scale_pair = {1, 1};
#pragma unroll
for(int pv_n_loop=0; pv_n_loop<(K/kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK/32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
}
#else
// 928 及之前的架构没 pk_mul 指令
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; vec_idx++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].f32[vec_idx] *= scores_scale;
}
#endif
}
}
}
}
}
int8_kvcache_scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max_cur, softmax_scale_log2);
DataType1 scores_sum_cur[(WARP_M/32)];
for (int mi = 0; mi < (WARP_M/32); ++mi) {
scores_sum_cur[mi].u64 = 0x0;
}
int8_kvcache_reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum_cur);
if constexpr (WARP_NUM == 4) {
// 重新求多个 wave 的归一化和
DataType2* sum_lds = max_lds;
if(lane_id < 16) {
// 每个 wave 的归一化和写到 lds
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
sum_lds[WARP_ID*WARP_M + mi*32 + lane_id*2] = scores_sum_cur[mi].f32[0];
} else {
*(__float2*)(sum_lds + WARP_ID*WARP_M + mi*32 + lane_id*2) = scores_sum_cur[mi].u64; // M_MMAC_COUNT doesn't exceed 2
}
}
__syncthreads();
// 0 号 wave reduce 其他 wave 的归一化和
if (WARP_ID == 0) {
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
float tmp = sum_lds[mi*32 + lane_id*2];
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++) {
tmp += sum_lds[warp_loop*WARP_M + mi*32 + lane_id*2];
}
sum_lds[mi*32 + lane_id*2] = tmp;
} else {
__float2 cur_wave_sum = *(__float2*)(sum_lds + mi*32 + lane_id*2);
#pragma unroll
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop*WARP_M + mi*32 + lane_id*2);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
#endif
}
*(__float2*)(sum_lds + mi*32 + lane_id*2) = cur_wave_sum;
}
}
}
}
__syncthreads();
// 4 个 wave 从 lds 中读取最后 reduce 的归一化和
for (int mi = 0; mi < (WARP_M/32); ++mi) {
if constexpr (M_MMAC_COUNT == 1) {
scores_sum_cur[mi*2].f32[0] = sum_lds[mi*32 + (lane_id&15) * 2];
} else {
scores_sum_cur[mi*2].u64 = *(__float2*)(sum_lds + mi*32 + (lane_id&15) * 2);
}
}
__syncthreads(); // 以免后续的 buffer_load_to_lds 调度到这之前
}
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
// #######################################################
scores_max[mi].u64 = scores_max_cur[mi].u64;
#else
scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
// #######################################################
scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
#endif
}
}
};
/*
 * Narrow the fp32 score fragments in s_reg to the 16-bit Element type and
 * store them at the same [tile][fragment].f16[slot] position in p_reg.
 *
 * Tiles are indexed as n_tile * (WARP_M/32) + m_tile. For every tile, the
 * fragments {0*2 + sub_m, 1*2 + sub_m} (sub_m < M_MMAC_COUNT) are converted,
 * four values each.
 *
 *   Element == half_t  : DownCast<float, Element, false>
 *   Element == bhalf_t : inlineasm_float2bfloat16_nonan (the #if below can be
 *                        flipped to fall back to DownCast)
 *   any other Element  : nothing is written
 */
template <int WARP_M, int WARP_N, int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void int8_kvcache_convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M/32)*(WARP_N/32)][4], union_vec4_fp32 s_reg[(WARP_M/32)*(WARP_N/32)][4]) {
    constexpr int kTilesM = WARP_M / 32;
    constexpr int kTilesN = WARP_N / 32;
    #pragma unroll
    for (int tile_n = 0; tile_n < kTilesN; ++tile_n) {
        #pragma unroll
        for (int tile_m = 0; tile_m < kTilesM; ++tile_m) {
            const int tile = tile_n * kTilesM + tile_m;
            #pragma unroll
            for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                #pragma unroll
                for (int pair = 0; pair < 2; ++pair) {
                    // Same write order as the hand-unrolled form:
                    // (frag 0, elem 0), (frag 1, elem 0), (frag 0, elem 1), (frag 1, elem 1).
                    #pragma unroll
                    for (int elem = 0; elem < 2; ++elem) {
                        const int slot = pair * 2 + elem;
                        #pragma unroll
                        for (int sub_n = 0; sub_n < 2; ++sub_n) {
                            const int frag = sub_n * 2 + sub_m;
                            if constexpr (std::is_same<Element, half_t>::value) {
                                p_reg[tile][frag].f16[slot] = DownCast<float, Element, false>(
                                    s_reg[tile][frag].f32[slot]);
                            } else if constexpr (std::is_same<Element, bhalf_t>::value) {
#if 1
                                // Cheaper bf16 conversion than HIP's __float2bfloat16.
                                p_reg[tile][frag].f16[slot] = inlineasm_float2bfloat16_nonan(
                                    s_reg[tile][frag].f32[slot]);
#else
                                p_reg[tile][frag].f16[slot] = DownCast<float, Element, false>(
                                    s_reg[tile][frag].f32[slot]);
#endif
                            }
                        }
                    }
                }
            }
        }
    }
}
#include "numeric_types.h"
/*
 * Sum the per-wave partial output accumulators (acc_o) across all waves of the
 * workgroup, using acc_o_lds as scratch, and load the summed values back into
 * every participating lane's acc_o registers.
 *
 * Template parameters
 *   REUSE_KV_TIMES : compile-time row count; when <= 0 the runtime seqlen_q is used instead.
 *   kHeadDim       : head dimension; (kHeadDim / kBlockK) column blocks are reduced.
 *   kBlockK        : column block width; (kBlockK / 32) 32-wide tiles per block.
 *   WARP_M         : only used to compute acc_o tile indices.
 *   M_MMAC_COUNT   : number of row sub-tiles handled per lane.
 *   WARP_NUM       : number of waves contributing partial results.
 *
 * Parameters
 *   acc_o     : per-lane accumulator fragments, updated in place.
 *   acc_o_lds : scratch buffer; the offsets below index it as WARP_NUM regions of
 *               EVEN_REUSE_KV_TIMES * __kHeadDim floats each.
 *   seqlen_q  : runtime row count (only read when REUSE_KV_TIMES <= 0).
 *   WARP_ID   : index of the calling wave.
 *   lane_id   : lane index inside the wave.
 *
 * Only lanes with (lane_id & 15) < HALF_REUSE_KV_TIMES take part. The acc_o index
 * used here has no m-tile term, so only the first m-tile of each column tile is touched.
 *
 * NOTE(review): the __syncthreads() calls sit inside the lane-dependent branch. The
 * predicate depends only on (lane_id & 15) and on a value derived from
 * REUSE_KV_TIMES / seqlen_q, so presumably every wave evaluates it the same way -
 * confirm that this is sufficient on the target hardware.
 */
template<int REUSE_KV_TIMES, int kHeadDim, int kBlockK, int WARP_M, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void kvcache_acco_reduce(
vec4_Accum<ElementAccum> acc_o[(kHeadDim / kBlockK) * ((WARP_M / 32) * (kBlockK / 32))][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int WARP_ID,
int lane_id) {
// When REUSE_KV_TIMES is not given as a template value, derive the row count from
// seqlen_q. Either way it is rounded up to the next even number.
int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2: ((seqlen_q + 1) / 2) * 2;
int HALF_REUSE_KV_TIMES = EVEN_REUSE_KV_TIMES >> 1;
// Each lane addresses the row pair selected by its low 4 lane bits.
int q_seq_idx = (lane_id & 15);
// LDS row stride in floats: kHeadDim, padded by 4 unless REUSE_KV_TIMES >= 16 or kHeadDim == 512.
constexpr int __kHeadDim = (REUSE_KV_TIMES >= 16 or kHeadDim == 512) ? kHeadDim: kHeadDim + 4/*<=15 can use misalign to reduce bank conflicts, but >16 may lead to lds>32KB, less waves per SIMD*/;
if (q_seq_idx < HALF_REUSE_KV_TIMES) {
// ####################################################################################################################################################
// Step 1: every wave writes its own partial acc_o values into its private LDS region
// (region base = WARP_ID * EVEN_REUSE_KV_TIMES * __kHeadDim), 4 floats per store.
for (int h_idx = 0; h_idx < (kHeadDim / kBlockK); h_idx++) {
for (int k_idx = 0; k_idx < (kBlockK / 32); k_idx++) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; min_tile_m++) {
for (int min_tile_k = 0; min_tile_k < 2; min_tile_k++) {
int lds_offset = WARP_ID * EVEN_REUSE_KV_TIMES * __kHeadDim + q_seq_idx * 2 * __kHeadDim + min_tile_m * __kHeadDim + h_idx * kBlockK + k_idx * 32 + min_tile_k * 16 + (lane_id >> 4) * 4;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[h_idx * ((WARP_M / 32) * (kBlockK / 32)) + k_idx * (WARP_M / 32)][min_tile_k * 2 + min_tile_m].f32;
}
}
}
}
__syncthreads();
// ####################################################################################################################################################
// Step 2 (exactly 4 waves): all waves cooperate in the summation. Wave w handles the
// element at offset +w inside every 4-float group and accumulates into region 0.
if constexpr (WARP_NUM == 4) {
for (int h_idx = 0; h_idx < (kHeadDim / kBlockK); h_idx++) {
for (int k_idx = 0; k_idx < (kBlockK / 32); k_idx++) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; min_tile_m++) {
union_vec2_fp32 acc_tmp;
// lds_offset0 / lds_offset1 are the min_tile_k == 0 / 1 positions (16 floats apart).
int lds_offset0 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 0 * 16 + (lane_id >> 4) * 4 + WARP_ID;
int lds_offset1 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 1 * 16 + (lane_id >> 4) * 4 + WARP_ID;
// Presumably the (0, 16) arguments are element offsets, so one call fetches both
// positions above - TODO confirm against the builtin's documentation.
acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
union_vec2_fp32 acc_tmp_wave1;
acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2;
acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3;
acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
acc_tmp.f32[1] += acc_tmp_wave3.f32[1];
// ds_write2_b32
acc_o_lds[lds_offset0] = acc_tmp.f32[0];
acc_o_lds[lds_offset1] = acc_tmp.f32[1];
}
}
}
__syncthreads();
} else if constexpr (WARP_NUM > 1) {
// Step 2 (any other multi-wave count): wave 0 alone sums all regions, one float at a time.
if (WARP_ID == 0) {
for (int h_idx = 0; h_idx < (kHeadDim / kBlockK); h_idx++) {
for (int k_idx = 0; k_idx < (kBlockK / 32); k_idx++) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; min_tile_m++) {
for (int min_tile_k = 0; min_tile_k < 2; min_tile_k++) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
int lds_offset = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + min_tile_k * 16 + (lane_id >> 4) * 4 + vec_idx;
float acc_tmp_wave0 = acc_o_lds[lds_offset];
for (int loop = 1; loop < WARP_NUM; loop++) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * __kHeadDim];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
}
}
}
}
}
}
__syncthreads();
}
// ####################################################################################################################################################
// Step 3: every wave reads the final sums back from region 0 into its registers.
for (int h_idx = 0; h_idx < (kHeadDim / kBlockK); h_idx++) {
for (int k_idx = 0; k_idx < (kBlockK / 32); k_idx++) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; min_tile_m++) {
for (int min_tile_k = 0; min_tile_k < 2; min_tile_k++) {
int lds_offset = q_seq_idx * 2 * __kHeadDim + min_tile_m * __kHeadDim + h_idx * kBlockK + k_idx * 32 + min_tile_k * 16 + (lane_id >> 4) * 4;
acc_o[h_idx * ((WARP_M / 32)*(kBlockK / 32)) + k_idx * (WARP_M / 32)][min_tile_k * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
}
}
}
}
}
}
\ No newline at end of file
#pragma once
#include "numeric_types.h"
/*
 * Cross-wave reduction of the partial output accumulators (acc_o) through LDS
 * for the 16x32 tile layout.
 *
 * Template parameters
 *   REUSE_KV_TIMES : compile-time row count; when <= 0 the runtime seqlen_q is used.
 *   K_LOOP_COUNT   : number of column blocks held in acc_o.
 *   K_WARP_COUNT   : 32-wide tiles per column block.
 *   M_WARP_COUNT   : m-tiles per column tile (only used for acc_o indexing).
 *   M_MMAC_COUNT   : row sub-tiles per lane (sub-tile t maps to LDS rows +16*t).
 *   WARP_NUM       : number of waves contributing partial results.
 *   Padding        : extra floats appended to each LDS row in the generic path.
 *
 * Parameters
 *   acc_o     : per-lane accumulator fragments, updated in place.
 *   acc_o_lds : LDS scratch buffer.
 *   seqlen_q  : runtime row count (only read when REUSE_KV_TIMES <= 0).
 *   warp_id   : index of the calling wave.
 *   lane_id   : lane index inside the wave.
 *
 * Two paths:
 *   - OPT_FOR_HDIM128 (gfx938/gfx946, WARP_NUM == 4, M_MMAC_COUNT == 1, Padding == 0):
 *     processes WARP_NUM column blocks per iteration. Wave w sums, across all waves,
 *     the block at index k_loop + w and leaves the result in acc_o[k_loop + 0].
 *   - generic path: one column block at a time; partials are written to per-wave
 *     LDS regions, summed into region 0 and read back by every wave.
 *
 * The generic path addresses the accumulator tile as
 * (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT, the same layout used by
 * kvcache_acco_reduce and kvcache_epilugue_rescale_acco. The previous expression
 * h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT matched it only for k_idx == 0
 * and could index past the end of acc_o when K_WARP_COUNT > 1.
 *
 * NOTE(review): in the generic path __syncthreads() is reached inside a
 * lane-dependent branch; the predicate is presumably identical in every wave - confirm.
 */
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, int Padding, typename ElementAccum>
__forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
    vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum* acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id) {
#if defined(__gfx938__) || defined(__gfx946__)
    constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and Padding == 0); // Specialized optimization for headdim 128
#else
    constexpr int OPT_FOR_HDIM128 = false; // keep same as origin for archs <= gfx936
#endif
    if constexpr (OPT_FOR_HDIM128) {
        // #######################################################################################################################################
        // bank-conflicts free path, higher performance
        // #######################################################################################################################################
        constexpr int PREFETCH = WARP_NUM;
        #pragma unroll
        for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH) {
            // Every wave dumps PREFETCH blocks x 2 fragments into its own 2048-float region.
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                #pragma unroll
                for (int prefetch = 0; prefetch < PREFETCH; ++prefetch) {
                    vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
                    int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
                    lds_write_offset = reinterpret_cast<size_t>(acc_o_lds + lds_write_offset + lane_id * 4);
                    inlineasm_ds_write_b128(lds_write_offset, f32x4);
                }
            }
            union_vec4_fp32 data[2][WARP_NUM];
            constexpr int ds_bursts = PREFETCH;
            // The arguments to wait_lds_data_arrived are presumably counts of LDS operations
            // allowed to remain outstanding - verify against its definition before editing.
            {
                // Fetch, from every wave's region, the slice owned by this wave (fragment 0).
                constexpr int min_tile_n = 0;
                flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH);
                #pragma unroll
                for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
                    int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
                    inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
                }
                inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
                inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
            }
            {
                // Same for fragment 1.
                constexpr int min_tile_n = 1;
                flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH + ds_bursts);
                #pragma unroll
                for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
                    int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
                    inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
                }
                inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
                inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
            }
            {
                // Accumulate fragment 0 as the reads complete.
                constexpr int min_tile_n = 0;
                #pragma unroll
                for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
                    flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor + ds_bursts);
                    inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
                    inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
                }
            }
            {
                // Accumulate fragment 1.
                constexpr int min_tile_n = 1;
                #pragma unroll
                for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
                    flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor);
                    inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
                    inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
                }
            }
            flash::wait_all_warp_arrived();
        }
    } else {
        // LDS row stride in floats.
        constexpr int kBlockK = K_WARP_COUNT * 32 + Padding;
        // When REUSE_KV_TIMES is not a template value, derive the (even) row count from seqlen_q.
        int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2: ((seqlen_q + 1) / 2) * 2;
        int q_seq_idx = (lane_id & 15);
        if (q_seq_idx < EVEN_REUSE_KV_TIMES) {
            for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
                // ####################################################################################################################################################
                // A wave holds seqlen_q x kHeadDim values in total, but to save LDS only
                // seqlen_q x kBlockK values are reduced per h_idx iteration.
                for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                    // Accumulator tile for (column block h_idx, 32-wide tile k_idx), first m-tile.
                    const int tile_id = (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT;
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT + q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4/*0~3*/) * 4/*0~15*/;
                            *(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[tile_id][min_tile_n * 2 + min_tile_m].f32;
                        }
                    }
                }
                __syncthreads();
                // Sum in LDS the acc_o data written by the individual waves.
                if constexpr (WARP_NUM == 4) {
                    for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                // Four values were written per lane above; each of the 4 waves now
                                // reduces one of those four positions (offset + warp_id).
                                int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + warp_id;
                                float acc_tmp_wave0 = acc_o_lds[lds_offset];
                                for (int loop = 1; loop < WARP_NUM; ++loop) {
                                    acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK];
                                }
                                acc_o_lds[lds_offset] = acc_tmp_wave0;
                            }
                        }
                    }
                }
                // Not exactly 4 waves: wave 0 alone performs the LDS reduction.
                else if constexpr (WARP_NUM > 1) {
                    if (warp_id == 0) {
                        for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                                        int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
                                        float acc_tmp_wave0 = acc_o_lds[lds_offset];
                                        for (int loop = 1; loop < WARP_NUM; ++loop) {
                                            acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK];
                                        }
                                        acc_o_lds[lds_offset] = acc_tmp_wave0;
                                    }
                                }
                            }
                        }
                    }
                }
                __syncthreads();
                // Every wave reads the final sums back from region 0.
                for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                    const int tile_id = (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT;
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4;
                            acc_o[tile_id][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
                        }
                    }
                }
                // The scratch area is overwritten by the next h_idx iteration.
                __syncthreads();
            }
        }
    }
}
\ No newline at end of file
#pragma once
#include "numeric_types.h"
/*
 * Final softmax normalisation: multiply every accumulator fragment by the
 * reciprocal of its row sum.
 *
 *   acc_o      : accumulator tiles, indexed as
 *                k_step * (M_WARP_COUNT * K_WARP_COUNT) + n_tile * M_WARP_COUNT + m_tile;
 *                scaled in place.
 *   scores_sum : per-row softmax denominators, one entry per m-tile with
 *                M_MMAC_COUNT values each.
 *
 * A row whose sum is 0 or NaN is left unscaled (factor 1).
 */
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void kvcache_epilugue_rescale_acco(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT]) {
    constexpr int kTilesPerKStep = M_WARP_COUNT * K_WARP_COUNT;
    #pragma unroll
    for (int k_step = 0; k_step < K_LOOP_COUNT; ++k_step) {
        #pragma unroll
        for (int m_tile = 0; m_tile < M_WARP_COUNT; ++m_tile) {
            #pragma unroll
            for (int n_tile = 0; n_tile < K_WARP_COUNT; ++n_tile) {
                const int tile_id = k_step * kTilesPerKStep + (n_tile * M_WARP_COUNT + m_tile);
                #pragma unroll
                for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                    ElementAccum row_sum = scores_sum[m_tile].f32[sub_m];
                    // row_sum != row_sum is the NaN test.
                    ElementAccum reciprocal = (row_sum == 0.f || row_sum != row_sum) ? 1.f : 1.f / row_sum;
                    __float2 reciprocal_pair = {reciprocal, reciprocal};
                    #pragma unroll
                    for (int sub_n = 0; sub_n < 2; ++sub_n) {
                        const int frag = sub_n * 2 + sub_m;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                        // Packed multiply: two floats per instruction.
                        for (int pair_idx = 0; pair_idx < 2; ++pair_idx) {
                            acc_o[tile_id][frag].u64[pair_idx] = __builtin_hcu_pk_mul_f32(
                                acc_o[tile_id][frag].u64[pair_idx],
                                reciprocal_pair
                            );
                        }
#else
                        // Scalar fallback for targets without the packed multiply.
                        for (int value_idx = 0; value_idx < 4; ++value_idx) {
                            acc_o[tile_id][frag].f32[value_idx] *= reciprocal;
                        }
#endif
                    }
                }
            }
        }
    }
}
/*
 * Write the per-row softmax statistics (running sum, and running max multiplied
 * by scale_softmax) to scores_sum_ptr / scores_max_ptr.
 *
 * Nothing is written unless Split is true or FA_DEBUG_SUM_MAX is defined.
 * Only callers with thread_id < 16 store (and, for Is_16x32, only when
 * headdim_split_id == 0). Rows >= seqlen_q_limit are skipped.
 *
 * Row mapping:
 *   Is_16x32 : mi * 32 + lane_id + min_tile_m * 16
 *   otherwise: warp_id * M_WARP_COUNT * 32 + mi * 32 + thread_id * 2 + min_tile_m
 */
template<bool Split, bool Is_16x32, int M_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void kvcache_epilogue_store_max_sum(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum *scores_max_ptr,
    ElementAccum *scores_sum_ptr,
    ElementAccum scale_softmax,
    int warp_id,
    int thread_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit
) {
#ifdef FA_DEBUG_SUM_MAX
    constexpr bool kDebugStore = true;
#else
    constexpr bool kDebugStore = false;
#endif
    if constexpr (Split or kDebugStore) {
        // The other 16-lane groups hold the same max/sum values; one copy is enough.
        bool lane_stores = thread_id < 16;
        if constexpr (Is_16x32) {
            lane_stores = lane_stores and headdim_split_id == 0;
        }
        if (!lane_stores) {
            return;
        }
        #pragma unroll
        for (int tile_m = 0; tile_m < M_WARP_COUNT; ++tile_m) {
            #pragma unroll
            for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                int row;
                if constexpr (Is_16x32) {
                    row = tile_m * 32 + lane_id + sub_m * 16;
                } else {
                    row = warp_id * M_WARP_COUNT * 32 + tile_m * 32 + thread_id * 2 + sub_m;
                }
                if (row >= seqlen_q_limit) {
                    continue;
                }
                scores_sum_ptr[row] = scores_sum[tile_m].f32[sub_m];
                scores_max_ptr[row] = scores_max[tile_m].f32[sub_m] * scale_softmax;
            }
        }
    }
}
/*
 * Variable-length variant of the max/sum store: identical row selection, but the
 * destination index is remapped as
 *     (row / ngroups) + (row % ngroups) * total_q.
 *
 * Nothing is written unless Split is true or FA_DEBUG_SUM_MAX is defined.
 * Only callers with thread_id < 16 store (and, for Is_16x32, only when
 * headdim_split_id == 0). Rows >= seqlen_q_limit are skipped.
 * ngroups must be non-zero (it is used as a divisor).
 */
template<bool Split, bool Is_16x32, int M_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void kvcache_varlen_epilogue_store_max_sum(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum *scores_max_ptr,
    ElementAccum *scores_sum_ptr,
    ElementAccum scale_softmax,
    int warp_id,
    int thread_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit,
    int total_q,
    int ngroups
) {
#ifdef FA_DEBUG_SUM_MAX
    constexpr bool kDebugStore = true;
#else
    constexpr bool kDebugStore = false;
#endif
    if constexpr (Split or kDebugStore) {
        // The other 16-lane groups hold the same max/sum values; one copy is enough.
        bool lane_stores = thread_id < 16;
        if constexpr (Is_16x32) {
            lane_stores = lane_stores and headdim_split_id == 0;
        }
        if (!lane_stores) {
            return;
        }
        #pragma unroll
        for (int tile_m = 0; tile_m < M_WARP_COUNT; ++tile_m) {
            #pragma unroll
            for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                int row;
                if constexpr (Is_16x32) {
                    row = tile_m * 32 + lane_id + sub_m * 16;
                } else {
                    row = warp_id * M_WARP_COUNT * 32 + tile_m * 32 + thread_id * 2 + sub_m;
                }
                if (row >= seqlen_q_limit) {
                    continue;
                }
                const int dst = (row / ngroups) + (row % ngroups) * total_q;
                scores_sum_ptr[dst] = scores_sum[tile_m].f32[sub_m];
                scores_max_ptr[dst] = scores_max[tile_m].f32[sub_m] * scale_softmax;
            }
        }
    }
}
/*
 * Store the per-row log-sum-exp, computed as
 *     scores_max * scale_softmax + __logf(scores_sum),
 * into softmax_lse_ptr.
 *
 * Only callers with thread_id < 16 store (and, for Is_16x32, only when
 * headdim_split_id == 0). Rows >= seqlen_q_limit are skipped.
 * With Is_Varlen the destination index is (row / ngroups) + (row % ngroups) * total_q;
 * otherwise it is the row itself and total_q / ngroups are not read.
 */
template<bool Is_Varlen, bool Is_16x32, int M_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void kvcache_epilogue_store_softmax_lse(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum *softmax_lse_ptr,
    ElementAccum scale_softmax,
    int warp_id,
    int thread_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit,
    int total_q,
    int ngroups
) {
    bool lane_stores = thread_id < 16;
    if constexpr (Is_16x32) {
        lane_stores = lane_stores and headdim_split_id == 0;
    }
    if (!lane_stores) {
        return;
    }
    #pragma unroll
    for (int tile_m = 0; tile_m < M_WARP_COUNT; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const ElementAccum row_lse = scores_max[tile_m].f32[sub_m] * scale_softmax + __logf(scores_sum[tile_m].f32[sub_m]);
            int row;
            if constexpr (Is_16x32) {
                row = tile_m * 32 + lane_id + sub_m * 16;
            } else {
                row = warp_id * M_WARP_COUNT * 32 + tile_m * 32 + thread_id * 2 + sub_m;
            }
            if (row >= seqlen_q_limit) {
                continue;
            }
            int dst = row;
            if constexpr (Is_Varlen) {
                dst = (row / ngroups) + (row % ngroups) * total_q;
            }
            softmax_lse_ptr[dst] = row_lse;
        }
    }
}
/*
 * Write the output accumulators to global memory, converting each value from
 * ElementAccum to SplitkvAccumType.
 *
 * Destination base:
 *   Split  : params.oaccum_ptr + row_offset_o + split_id * params.b * params.o_batch_stride
 *   !Split : params.o_ptr      + row_offset_o
 * where row_offset_o = bidb * o_batch_stride + bidh * o_head_stride + headdim_split_id * kHeadDimVSplit.
 *
 * Per-lane mapping:
 *   row    = m_block * kBlockM + warp_m_idx * 32
 *            + (Is_16x32 ? (lane_id & 15) + min_tile_m * 16 : (lane_id & 15) * 2 + min_tile_m)
 *   column = k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + (lane_id >> 4) * 2 + min_tile_n
 * Rows >= params.seqlen_q are skipped.
 *
 * When WARP_NUM == 4 each wave stores only the vector component vec_index == warp_id;
 * otherwise every wave stores all four components.
 *
 * All element offsets are computed in 64-bit arithmetic: the row stride is kept as
 * int64_t so that row * stride cannot wrap a 32-bit int for large outputs.
 *
 * NOTE(review): the tile index here is k_loop*M*K + warp_m_idx*K_WARP_COUNT + k_tile_idx,
 * whereas kvcache_epilugue_rescale_acco uses k_loop*M*K + ni*M_WARP_COUNT + mi. The two
 * agree only when M_WARP_COUNT == 1 or K_WARP_COUNT == 1 - confirm the intended layout.
 *
 * The function ends by waiting for all outstanding vector-memory operations
 * (s_waitcnt vmcnt(0)).
 */
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Split, bool Is_16x32, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_epilogue_store_output(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id) {
    // 64-bit stride: every product below that involves it is evaluated in int64_t.
    const int64_t output_seqlen_stride = params.o_row_stride;
    const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + bidh * int64_t(params.o_head_stride) + headdim_split_id * kHeadDimVSplit;
    SplitkvAccumType* o_ptr = Split
        ? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ int64_t(split_id) * params.b * params.o_batch_stride
        : reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
    int pv_lane_seq_idx = lane_id & 15;
    int pv_lane_head_dim_idx = lane_id >> 4;
    #pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; ++k_loop) {
        #pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
            #pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
                // acquire tile 32x32 id
                int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    // seqlen_q offset
                    int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m);
                    if (seqlen_q_idx < params.seqlen_q) {
                        if constexpr (WARP_NUM == 4) { // for 4 waves, the store is shared between the waves
#if defined(__gfx938__) || defined(__gfx946__)
                            // Both min_tile_n values are adjacent in memory: store them as one pair.
                            int vec_index = warp_id;
                            int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2;
                            vec2_Element<SplitkvAccumType> result = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[vec_index], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
                            *(vec2_Element<SplitkvAccumType>*)(o_ptr + pv_global_addr) = result;
#else
                            #pragma unroll 2
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                int mmac_id = min_tile_m + min_tile_n * 2;
                                int vec_index = warp_id;
                                int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                                ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                                o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                            }
#endif
                        } else { // non-4-waves should use this, but lead to performance drop when 4 waves per SIMD
                            #pragma unroll
                            for (int vec_index = 0; vec_index < 4; ++vec_index) {
                                #pragma unroll 2
                                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                    // which mmac fragment of the current 32x32 tile
                                    int mmac_id = min_tile_m + min_tile_n * 2;
                                    int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                                    ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                                    o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
    asm volatile("s_waitcnt vmcnt(0)");
}
/*
 * Variable-length variant of the output store. The caller supplies row_offset_o
 * and actual_seqlen_q; bidb, bidh and headdim_split_id are accepted but not read here.
 *
 * Destination base:
 *   Split  : params.oaccum_ptr + row_offset_o
 *            + split_id * (params.ngroups * params.total_q * params.o_row_stride)
 *   !Split : params.o_ptr      + row_offset_o
 *
 * Per-lane mapping:
 *   row    = m_block * kBlockM + warp_m_idx * 32
 *            + (Is_16x32 ? (lane_id & 15) + min_tile_m * 16 : (lane_id & 15) * 2 + min_tile_m)
 *   The row is split into (row / ngroups, row % ngroups) and addressed as
 *       (row / ngroups) * ngroups * o_row_stride + (row % ngroups) * o_head_stride
 *   column = k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + (lane_id >> 4) * 2 + min_tile_n
 * Rows >= actual_seqlen_q are skipped. params.ngroups must be non-zero.
 *
 * When WARP_NUM == 4 each wave stores only the vector component vec_index == warp_id;
 * otherwise every wave stores all four components.
 *
 * All element offsets are computed in 64-bit arithmetic: the row stride is kept as
 * int64_t so that row * ngroups * stride cannot wrap a 32-bit int for large outputs.
 *
 * NOTE(review): the tile index here is k_loop*M*K + warp_m_idx*K_WARP_COUNT + k_tile_idx,
 * whereas kvcache_epilugue_rescale_acco uses k_loop*M*K + ni*M_WARP_COUNT + mi. The two
 * agree only when M_WARP_COUNT == 1 or K_WARP_COUNT == 1 - confirm the intended layout.
 *
 * The function ends by waiting for all outstanding vector-memory operations
 * (s_waitcnt vmcnt(0)).
 */
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Split, bool Is_16x32, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_varlen_epilogue_store_output(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int64_t row_offset_o,
    int actual_seqlen_q,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id) {
    // 64-bit stride: every product below that involves it is evaluated in int64_t.
    const int64_t output_seqlen_stride = params.o_row_stride;
    const int64_t row_offset_split = params.ngroups * int64_t(params.total_q) * params.o_row_stride;
    SplitkvAccumType* o_ptr = Split
        ? reinterpret_cast<SplitkvAccumType *>(params.oaccum_ptr) + row_offset_o + /*which split*/ split_id * row_offset_split
        : reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
    int pv_lane_seq_idx = lane_id & 15;
    int pv_lane_head_dim_idx = lane_id >> 4;
    #pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; ++k_loop) {
        #pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
            #pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
                // acquire tile 32x32 id
                int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    // seqlen_q offset
                    int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m);
                    if (seqlen_q_idx < actual_seqlen_q) {
                        // Split the packed row index once; all store variants below share it.
                        int true_seqlen_q = seqlen_q_idx / params.ngroups;
                        int true_group_id = seqlen_q_idx % params.ngroups;
                        // Offset of the first element of this (row, tile) in the output.
                        int64_t pv_tile_base = true_seqlen_q * params.ngroups * output_seqlen_stride + true_group_id * int64_t(params.o_head_stride) + k_loop * kBlockK + k_tile_idx * 32;
                        if constexpr (WARP_NUM == 4) { // for 4 waves, the store is shared between the waves
#if defined(__gfx938__) || defined(__gfx946__)
                            // Both min_tile_n values are adjacent in memory: store them as one pair.
                            int vec_index = warp_id;
                            int64_t pv_global_addr = pv_tile_base + vec_index * 8 + pv_lane_head_dim_idx * 2;
                            vec2_Element<SplitkvAccumType> result = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[vec_index], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
                            *(vec2_Element<SplitkvAccumType>*)(o_ptr + pv_global_addr) = result;
#else
                            #pragma unroll 2
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                int mmac_id = min_tile_m + min_tile_n * 2;
                                int vec_index = warp_id;
                                int64_t pv_global_addr = pv_tile_base + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                                ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                                o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                            }
#endif
                        } else { // non-4-waves should use this, but lead to performance drop when 4 waves per SIMD
                            #pragma unroll
                            for (int vec_index = 0; vec_index < 4; ++vec_index) {
                                #pragma unroll 2
                                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                    // which mmac fragment of the current 32x32 tile
                                    int mmac_id = min_tile_m + min_tile_n * 2;
                                    int64_t pv_global_addr = pv_tile_base + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                                    ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                                    o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
    asm volatile("s_waitcnt vmcnt(0)");
}
\ No newline at end of file
#include "kvcache_pv_gemm_prefetch_k_3stage.h"
/*
 * kvcache_pv_gemm_prefetch_k
 *
 * Accumulates P x V into pv_reg, one kBlockN-wide slice of the head dimension at a
 * time (n_loop runs over kHeadDim / kBlockN slices).
 *
 * For every slice the function
 *   1. issues V_LOAD_REQUESTS calls of BUFFER_LOAD_FUNC that copy V from gV into this
 *      warp's part of v_lds (v_lds is advanced by WARP_ID * STAGES * WARP_K * kBlockN),
 *   2. waits with buffer_load_lds_dwordx1_wait_nosync and reads the slice from LDS into
 *      v_reg with inline_ds_read2_b32_no_wait_bytes,
 *   3. in two halves (min_tile_k = 0, then 1) packs p_reg / v_reg into vec4 operands and
 *      accumulates them into pv_reg through flash::mmac.
 *
 * STAGES == 2: the loop starts at n_loop = 1; iteration n_loop issues the load of slice
 * n_loop and computes slice n_loop - 1, and the trailing `if constexpr (STAGES == 2)` block
 * computes the last slice. Slice 0 is never loaded in this function -- presumably it has
 * been prefetched into stage 0 before entry; TODO confirm against the caller.
 *
 * Parameters
 *   gV                   handle handed to BUFFER_LOAD_FUNC as the global source of V.
 *   gK, k_lds            not referenced in this body (neither are PREFETCH_K / WARP_NUM).
 *   v_lds                LDS staging area for V; offset per warp inside the function.
 *   p_reg                left operand; only p_reg[0][...] is read here.
 *   pv_reg               accumulators, updated in place.
 *   WARP_ID              selects this warp's LDS region and its rows of V.
 *   vcache_seqlen_stride element stride between consecutive V rows.
 *   max_seq_kv_offset    row clamp: loaded rows are limited to max_seq_kv_offset - 1.
 *                        NOTE(review): with the default of -1 the clamp becomes -2, i.e. a
 *                        negative row offset; presumably every caller passes a real value
 *                        -- confirm.
 *
 * The function ends with __syncthreads().
 */
template<bool PREFETCH_K, int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_pv_gemm_prefetch_k(
vec4_uint gV,
vec4_uint gK,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (WARP_K / 32)][4],
vec4_Accum<ElementAccum> pv_reg[(kHeadDim / kBlockN) * (WARP_M / 32) * (kBlockN / 32)][4],
int WARP_ID,
int vcache_seqlen_stride,
int max_seq_kv_offset = -1) {
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
// Register copy of V, one group of (WARP_N / 32) tiles per stage.
// NOTE(review): the size uses 32 * WARP_N while the indexing below uses WARP_N * WARP_K;
// the two agree only when WARP_K == 32 -- confirm that is the only instantiation.
union_vec2_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32)][4];
// Precompute some common sub-expressions
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 threads read one row
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 threads read one row (not used below)
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 threads read one row
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1; when reading LDS, 8x32 data is read in the thread order [0, 16, 0, 16, 32, 48, 32, 48], every 32 threads read one 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // for 8x32 data each thread of a wave reads 4 halves, i.e. 2 dwords, with ds_read2_b32; with the read pattern above the second dword is 32 dwords further
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // rows read by one warp per load; loadx2, each row of 32 elements needs 32 / (4) = 8 threads. NOTE(review): the code below uses lane_id & 3 and 8 halves per thread, i.e. 4 lanes per row -- the "8 threads" figure looks stale
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // elements loaded by one warp per request
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // number of load instructions the whole workgroup has to issue
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // number of load instructions each warp has to issue
constexpr int READ_ELEMENT_COUNT = 8; // halves read by each thread per load
int v_lane_headdim_n_idx = lane_id & 3; // which dwordx2 of this warp the current lane handles
int base = laneid_shfl_2 & 0xc; // smallest id of the 4-thread group this lane belongs to
int tail = laneid_shfl_2 & 0x3; // index of the thread inside its group
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// which row's 4 dwords this thread reads
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 8x32 is read/written at a time: threads 0-31 read the first 4x32, threads 32-63 the second 4x32; each 4x32 is read in the thread pattern [0, 16, 0, 16]
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
// Default path: 4 rows per request, 16 lanes per row, 2 elements per lane.
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 8x32 is read/written at a time: threads 0-31 read the first 4x32, threads 32-63 the second 4x32; each 4x32 is read in the thread pattern [0, 16, 0, 16]
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN;
// With two stages the loop starts one slice ahead: it loads slice n_loop into stage 1
// first and computes slice n_loop - 1 out of the other stage.
constexpr int N_LOOP_START = (STAGES == 2) ? 1: 0;
int stage_id = (STAGES == 2) ? 1: 0;
for (int n_loop = N_LOOP_START; n_loop < (kHeadDim / kBlockN); ++n_loop) {
// Global element offset of this warp's rows (WARP_ID * WARP_K rows down) at column n_loop * kBlockN.
int v_block_buffer_load_global_offset = WARP_ID * WARP_K * vcache_seqlen_stride + n_loop * kBlockN;
for(int load = 0; load < V_LOAD_REQUESTS; ++load) {
// NOTE(review): v_warp_buffer_load_k_id is computed but never used.
int v_warp_buffer_load_k_id = (load + WARP_ID) % V_LOAD_REQUESTS;
int v_warp_buffer_load_lds_offset = /* (v_warp_buffer_load_n_id * 32) + */load * READ_ONCE_COUNT;
// The three offsets below are halved before being handed to BUFFER_LOAD_FUNC
// (presumably element -> dword units for 2-byte elements; confirm against the helper).
int s_offset = v_block_buffer_load_global_offset / 2;
// Row index is clamped to max_seq_kv_offset - 1 so lanes past the limit re-read that row.
int v_offset = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * vcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + stage_id * WARP_K * kBlockN, gV, v_lds_offset, s_offset, v_offset);
}
// Switch to the other stage: the compute below consumes the stage that was not just written.
if constexpr (STAGES == 2) stage_id ^= 1;
// Move some of the computation needed by ds_read ahead of the wait, while the data is still in flight
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// NOTE(review): every (seq_idx, head_dim_idx) iteration overwrites precompute_v_lds_offset[vec_idx],
// so only the value of the last pair survives. This equals a single assignment only when
// WARP_K == 32 and WARP_N == 32 -- confirm those are the only instantiations.
// The "/ 2 * 4" turns an element offset into a byte offset, presumably for 2-byte elements.
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Wait for the loads feeding the stage that is about to be read. With two stages the
// template argument is V_LOAD_REQUESTS (presumably the number of loads allowed to stay
// outstanding, i.e. the ones just issued for the other stage -- confirm against the helper).
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// Issue the LDS -> register reads for the whole slice without waiting for them.
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * (WARP_N * WARP_K) / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// Split into two halves: 8 ds_read instructions are issued in total; the first 4 are permuted and used for mmac, then the remaining 4 are waited for, similar to
// NOTE(review): the loop above issues 4 * (WARP_K / 32) * (WARP_N / 32) ds_read2 calls and the
// waits below use lgkmcnt(2) / lgkmcnt(0); the 8 / 4 figures in the comment above may be stale.
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
// First half: consumes v_reg[...][0] and v_reg[...][1].
constexpr int min_tile_k = 0;
// First v_pack together the data needed from the p registers
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
// Partial wait: up to 2 of the ds_reads issued above may still be in flight.
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
// Repack v_reg into the vec4 right-hand operands, one per min_tile_n.
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * (WARP_N * WARP_K) / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
// Accumulate into pv_reg; with two stages the slice being computed is n_loop - 1.
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ping-pong buffer between ds and vgpr
{
// Second half: consumes v_reg[...][2] and v_reg[...][3] after all ds_reads have landed.
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * (WARP_N * WARP_K) / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = (STAGES == 2) ? n_loop - 1: n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// Two-stage epilogue: the loop above computed slices 0 .. last - 1; compute the last slice
// here. No new load is issued, so every wait below is for 0 outstanding loads.
if constexpr (STAGES == 2) {
int n_loop = (kHeadDim / kBlockN) - 1;
// Flip to the stage that the final loop iteration filled.
stage_id ^= 1;
{
// Move some of the computation needed by ds_read ahead of the wait, while the data is still in flight
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
// NOTE(review): as in the main loop, only the last (seq_idx, head_dim_idx) value is kept per vec_idx.
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// All branches are identical here (wait<0>); the structure mirrors the main loop.
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<0>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * (WARP_N * WARP_K) / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// Split into two halves: 8 ds_read instructions are issued in total; the first 4 are permuted and used for mmac, then the remaining 4 are waited for, similar to
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
// First half of the last slice: v_reg[...][0] and v_reg[...][1].
constexpr int min_tile_k = 0;
// First v_pack together the data needed from the p registers
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * (WARP_N * WARP_K) / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ping-pong buffer between ds and vgpr
{
// Second half of the last slice: v_reg[...][2] and v_reg[...][3].
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * (WARP_N * WARP_K) / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse together, need sync
}
#include "kvcache_pv_gemm_utils.h"
template<bool PREFETCH_K, int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_3stage(
vec4_uint gV,
vec4_uint gK,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (WARP_K / 32)][4],
vec4_Accum<ElementAccum> pv_reg[(kHeadDim/kBlockN) * (WARP_M / 32) * (kBlockN / 32)][4],
int WARP_ID,
int vcache_seqlen_stride,
int max_seq_kv_offset = 0) {
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
union_vec2_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32)][4];
// 预先计算一些公共表达式
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 个线程读取一行
int laneid_shfl_3 = lane_id >> 3; // 0 ~ 7, 8 个线程读取一行
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 个线程读取一行
int laneid_shfl_5 = lane_id >> 5; // 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
constexpr int NEXT_DWORD_OFFSET = 32; // 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 按照上面的读取方式, 第二个 dword 偏移 32 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // 一个 warp 一次读几行数据, loadx2, 每行 32 个元素需要 32 / (4) = 8 个线程
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // 一个 warp 一次 load 多少数据
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // 整个 workgroup 要发多少 load 指令
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // 每个 warp 要发多少 load 指令
constexpr int READ_ELEMENT_COUNT = 8; // 每个线程一次读取几个 half
int v_lane_headdim_n_idx = lane_id & 3; // 当前 lane 负责这个 warp 的第几个 dwordx2
int base = laneid_shfl_2 & 0xc; // 第几个4线程组的最小id
int tail = laneid_shfl_2 & 0x3; // 线程组中的第几个线程
int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);// 每个线程负责读取第几行数据的 4 个 dword
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5; // 0, 1, 2, 3 ---> 0, 2, 1, 3
int v_ds_read_offset = (laneid_shfl_5 * 4 + (laneid_shfl_4 & 1)) * 32 + (lane_id & 15) * 2; // 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN;
constexpr int N_LOOP_START = 0;
// for (int n_loop = N_LOOP_START; n_loop < (kHeadDim / kBlockN); ++n_loop)
int n_loop = 0;
int stage_id = 0;
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync < 2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
// #pragma unroll
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * WARP_N * WARP_K / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 1;
// stage_id = 1;
{
{
stage_id = 0;
int v_block_buffer_load_global_offset = WARP_ID * vcache_seqlen_stride * WARP_K + (n_loop + 2/*now, n_loop = 1 rather than 0*/) * kBlockN;
for(int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_k_id = (load + WARP_ID) % V_LOAD_REQUESTS;
int v_warp_buffer_load_lds_offset = /* (v_warp_buffer_load_n_id * 32) + */load * READ_ONCE_COUNT;
int s_offset = v_block_buffer_load_global_offset / 2;
int v_offset = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1) * vcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + stage_id * WARP_K * kBlockN, gV, v_lds_offset, s_offset, v_offset);
}
}
stage_id = 1;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync < 2 * V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * WARP_N * WARP_K / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 2;
stage_id = 2;
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
// #pragma unroll
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * WARP_N * WARP_K / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
n_loop = 3;
stage_id = 0;
{
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
// lds -> vgpr use ds_read_m; right matrix
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2 * 4;
}
}
}
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for(int seq_idx = 0; seq_idx < (WARP_K / 32); ++seq_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (WARP_N / 32); ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * WARP_N * WARP_K / (32 * 32) + head_dim_idx * (WARP_K / 32) + seq_idx][vec_idx].u64, NEXT_DWORD_OFFSET);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element<Element> v_vgprs[(32 * WARP_N) / (32 * 32)][2];
{
constexpr int min_tile_k = 0;
// 先把 p 寄存器需要的数据 v_pack 在一起
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(2)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
// ds 和 vgpr 之间的 ping-pang buffer
{
constexpr int min_tile_k = 1;
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0],
p_reg[0][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1],
p_reg[0][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1]
);
}
asm volatile("s_setprio 1");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
const int v_tile_id = stage_id * WARP_N * WARP_K / (32 * 32) + n_idx * (WARP_K / 32) + k_idx;
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][0 + min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][1 + min_tile_k * 2].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int k_idx = 0; k_idx < (WARP_K / 32); ++k_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * (WARP_M / 32) * (kBlockN / 32) + n_idx * (WARP_M / 32) + m_idx;
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 =
flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * (WARP_K / 32) + k_idx][/*min_tile_k * 2 + */min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
__syncthreads(); // here, K/V use more lds, and thus reuse togather, need sync
}
#include "kvcache_pv_gemm_utils_tile16x32.h"
/*
 * Per-wave P*V accumulation over K_LOOP_COUNT tiles of V, each kBlockN elements wide.
 *
 * For every tile (walked from the highest tile index down to 0) the function
 *   1. issues V_LOAD_REQUESTS buffer loads of V into this wave's LDS slice,
 *   2. waits for outstanding loads, then reads the tile back into v_reg with ds_read2,
 *   3. repacks p_reg / v_reg into vec4 operands and accumulates into pv_reg with flash::mmac.
 *
 * With STAGES == 2 the LDS slice is double-buffered: each iteration issues the loads for tile
 * n_loop into one stage while it consumes the other stage into pv tile n_loop + 1, and a final
 * epilogue consumes tile 0. The very first consumed stage is therefore not loaded by this
 * function; presumably an earlier prefetch (e.g. kvcache_prefetch_v_to_lds_tile16x32) fills it
 * -- confirm at the call site.
 *
 * Parameters:
 *   v_addr                 descriptor handed to the buffer-load helper for V.
 *   k_addr, k_lds          not referenced anywhere in this body.
 *   v_lds                  LDS base; offset below by warp_id * STAGES * WARP_K * kBlockN.
 *   p_reg                  left operand; only p_reg[0] is read here.
 *   pv_reg                 accumulators, updated in place.
 *   warp_id                wave index; selects the LDS slice and the starting V row.
 *   kvcache_seqlen_stride  multiplied by the (clamped) row index to form the load offset.
 *   max_seq_kv_offset      row indices are clamped to max_seq_kv_offset - 1. With the default
 *                          of -1 the clamp value is -2, so callers presumably always pass a
 *                          real bound -- TODO confirm.
 *
 * NOTE(review): p_vgprs below has 2 entries and is indexed by min_tile_m < M_MMAC_COUNT, so
 * M_MMAC_COUNT is presumably <= 2 -- confirm.
 */
template<int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int kvcache_seqlen_stride,
int max_seq_kv_offset=-1) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert(kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32");
// Register copy of the V tile(s): one [4]-entry group per (stage, n tile, k tile).
union_vec2_f16x2<Element> v_reg[STAGES * PV_K_WARP_COUNT * PV_N_WARP_COUNT][4];
// Precompute some common sub-expressions
int lane_id = threadIdx.x & 63;
int laneid_shfl_2 = lane_id >> 2; // 0 ~ 15, 4 threads read one row
int laneid_shfl_4 = lane_id >> 4; // 0 ~ 3, 16 threads read one row
constexpr int NEXT_DWORD_OFFSET = 64; // 8x32 data: each thread of a wave reads 4 halves, i.e. 2 dwords, with ds_read2_b32; the second dword is offset by 64 dwords
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr int READ_ONCE_LINES = 16; // rows a warp reads per request: loadx4, each thread reads 8 halves, a 32-half row needs 32 / 8 = 4 threads, so a 64-thread wave reads 16 rows
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32; // elements a warp loads per request, 16x32
constexpr int V_LDS_LOAD_NUM = kBlockN * WARP_K / READ_ONCE_COUNT; // total number of load requests a warp issues
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM; // total number of load requests a warp issues
constexpr int READ_ELEMENT_COUNT = 8; // halves each thread reads per request
int v_lane_headdim_n_idx = lane_id & 3; // which dwordx2 of this warp the current lane is responsible for
// base and tail are not referenced again in this function.
int base = (laneid_shfl_2 & 0xc); // lowest id of the 4-thread group this lane belongs to
int tail = (laneid_shfl_2 & 0x3); // position of this thread inside its 4-thread group
int v_lane_seq_k_idx = laneid_shfl_2; // global -> lds, coordinate along the seqlen direction
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2; // each read/write covers 8x32, read in the thread order [0, 16, 32, 48]
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
constexpr int READ_ONCE_LINES = 4;
constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;
constexpr int V_LDS_LOAD_NUM = (kBlockN * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM;
constexpr int READ_ELEMENT_COUNT = 2;
int v_lane_headdim_n_idx = lane_id & 15;
int v_lane_seq_k_idx = laneid_shfl_4; // 0-15 read row 0; 16-31 read row 1; 32-47 read row 2; 48-63 read row 3
int v_ds_read_offset = laneid_shfl_4 * 32 + (lane_id & 15) * 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
// each wave need 2 32x32 lds space
v_lds = v_lds + warp_id * STAGES * WARP_K * kBlockN;
// With two stages the loads start in stage 1 and the first consume (after the flip below) reads stage 0.
int stage_id = (STAGES == 2) ? 1: 0;
// With two stages the loop starts one tile lower, because the tile being loaded trails the tile being consumed by one.
constexpr int N_LOOP_START = (STAGES == 2) ? K_LOOP_COUNT - 2: K_LOOP_COUNT - 1;
for (int n_loop = N_LOOP_START; n_loop >= 0; --n_loop) {
// Issue the buffer loads for tile n_loop into the current stage.
int v_block_buffer_load_global_offset = n_loop * kBlockN;
#pragma unroll
for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
int v_warp_buffer_load_lds_offset = load * READ_ONCE_COUNT;
// The "/ 2" on the three offsets presumably converts 16-bit element counts to dword units -- TODO confirm against the load helper.
int v_gvoffset_s = v_block_buffer_load_global_offset / 2;
// Per-lane offset: column part plus the row index (clamped to max_seq_kv_offset - 1) times the row stride.
int v_gvoffset_v = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + min(v_lane_seq_k_idx + load * READ_ONCE_LINES + warp_id * WARP_K, max_seq_kv_offset - 1) * kvcache_seqlen_stride) / 2;
int v_lds_offset = v_warp_buffer_load_lds_offset / 2;
BUFFER_LOAD_FUNC(v_lds + stage_id * WARP_K * kBlockN, v_addr, v_lds_offset, v_gvoffset_s, v_gvoffset_v);
}
// Switch to the other stage: from here on stage_id names the stage that is consumed.
if constexpr (STAGES == 2) stage_id ^= 1;
// LDS read offsets are computed before the wait so the arithmetic overlaps with the loads in flight.
// NOTE(review): the array is indexed by vec_idx only, so the seq_idx / head_dim_idx loops overwrite
// the same slot and the last iteration wins; this looks correct only when
// PV_K_WARP_COUNT == PV_N_WARP_COUNT == 1 -- confirm.
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Wait for buffer loads. The helper is defined elsewhere; presumably its template argument is the
// number of loads allowed to stay outstanding, so the two-stage path leaves the loads just issued
// in flight -- TODO confirm.
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// LDS -> registers: one ds_read2 per vec_idx, returning two dwords NEXT_DWORD_OFFSET apart.
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
// Split into two halves: 8 ds_read instructions are issued in total; the first 4 are permuted and used for mmac, then the remaining 4 are waited for, and so on
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
// Pack the four f16 values of p_reg[0][min_tile_k * 2 + min_tile_m] into one vec4 operand.
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][min_tile_k * 2 + min_tile_m].f16[0],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[1],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[2],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
// Repack V: for each min_tile_n take element [min_tile_n] of both f16x2 pairs of two adjacent v_reg entries.
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2 + 1].f16x2[0][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2 + 1].f16x2[1][min_tile_n]
};
}
}
}
// Accumulate into pv_reg; every (m, k, n, min_tile_n, min_tile_m) combination issues one mmac.
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
// With two stages the data consumed here belongs to tile n_loop + 1 (the tile loaded one iteration earlier).
int n_loop_idx = (STAGES == 2) ? n_loop + 1: n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// Two-stage epilogue: consume tile 0, whose loads were issued by the last loop iteration.
if constexpr (STAGES == 2) {
int n_loop = 0;
stage_id ^= 1;
// Hoist the address arithmetic needed by ds_read above the wait, so it overlaps with waiting for the data to return
// NOTE(review): same overwrite pattern as in the main loop -- only the last
// (seq_idx, head_dim_idx) value is kept per vec_idx.
int precompute_v_lds_offset[4];
vec2_Element<Element> *v_lds_v2fp16 = (vec2_Element<Element> *)(v_lds);
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
__builtin_amdgcn_sched_barrier(0);
// All three paths below call the wait with <0>; the STAGES == 1 branch cannot be taken inside
// the enclosing STAGES == 2 block.
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<0>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
// LDS -> registers for the last tile.
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
#pragma unroll
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
// Split into two halves: 8 ds_read instructions are issued in total; the first 4 are permuted and used for mmac, then the remaining 4 are waited for, and so on
vec4_Element<Element> v_vgprs[PV_K_WARP_COUNT * PV_N_WARP_COUNT][2];
// ping-pong buffer between ds and vgpr
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
// Same P packing as in the main loop.
vec4_Element<Element> p_vgprs[2];
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
p_vgprs[min_tile_m] = make_vec4_f16(
p_reg[0][min_tile_k * 2 + min_tile_m].f16[0],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[1],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[2],
p_reg[0][min_tile_k * 2 + min_tile_m].f16[3]
);
}
asm volatile("s_setprio 1");
__builtin_amdgcn_sched_barrier(0);
// Same V repacking as in the main loop.
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
int v_tile_id = stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + n_idx * PV_K_WARP_COUNT + k_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n] = vec4_Element<Element>{
v_reg[v_tile_id][min_tile_k * 2].f16x2[0][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2].f16x2[1][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2 + 1].f16x2[0][min_tile_n],
v_reg[v_tile_id][min_tile_k * 2 + 1].f16x2[1][min_tile_n]
};
}
}
}
#pragma unroll
for (int m_idx = 0; m_idx < M_WARP_COUNT; ++m_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < PV_K_WARP_COUNT; ++k_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < PV_N_WARP_COUNT; ++n_idx) {
// The epilogue always accumulates into pv tile 0.
int n_loop_idx = n_loop;
int pv_tile_id = n_loop_idx * M_WARP_COUNT * PV_N_WARP_COUNT + n_idx * M_WARP_COUNT + m_idx;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
p_vgprs[min_tile_m],
v_vgprs[n_idx * PV_K_WARP_COUNT + k_idx][min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
asm volatile("s_setprio 0");
}
}
__syncthreads(); // here, K/V use more LDS and therefore share (reuse) it, so a block-wide sync is needed
}
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
/**
 * Issue the initial global -> LDS buffer loads of V for one wave, ahead of the QK GEMM.
 *
 * The function only issues loads; it never waits for them. Each group of loads is followed by
 * a scheduling barrier so the compiler keeps the issue order.
 *
 * Tiles issued (tile t starts t * kBlockN elements further along each row):
 *   STAGES == 1 : none, only the scheduling barrier is emitted
 *   STAGES == 2 : tile 0                -> stage slot stage_id
 *   STAGES >= 3 : tiles 0, 1 and 2      -> stage slots stage_id, stage_id + 1, stage_id + 2
 *
 * Every wave owns a contiguous LDS region of STAGES * WARP_K * kBlockN elements starting at
 * v_lds + WARP_ID * STAGES * WARP_K * kBlockN; one stage slot is WARP_K * kBlockN elements.
 * (The previous code hard-coded the per-wave stride as 3 for slots 1 and 2, which equals
 * STAGES only in the three-stage configuration.)
 *
 * @param gV                    descriptor handed to the buffer-load helper for V.
 * @param v_lds                 LDS base shared by all waves.
 * @param WARP_ID               wave index; selects the LDS region and the first V row
 *                              (WARP_ID * WARP_K).
 * @param vcache_seqlen_stride  multiplied by a row index to form the row offset.
 * @param max_seq_kv_offset     the per-lane row index is clamped to max_seq_kv_offset - 1.
 *                              NOTE(review): the clamp is applied to the row index *within*
 *                              this wave's slice (the WARP_ID * WARP_K part is added separately),
 *                              and the default 0 clamps every row to -1, so callers presumably
 *                              always pass a wave-relative bound -- confirm.
 *
 * The "/ 2" applied to all offsets presumably converts 16-bit element counts into dword units
 * expected by the load helper -- TODO confirm against inline_buffer_load_*_lds.
 */
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void kvcache_prefetch_v_to_lds(
    vec4_uint gV,
    Element* v_lds,
    int WARP_ID,
    int vcache_seqlen_stride,
    int max_seq_kv_offset=0) {
    // Lane index inside the 64-lane wave.
    const int lane_id = threadIdx.x & 63;
#if defined(USE_BUFFER_LOAD_DWORDX4)
    // 8 elements per lane, 4 lanes per 32-element row => 16 rows per request.
    constexpr int READ_ONCE_LINES = 16;
    constexpr int READ_ELEMENT_COUNT = 8;
    const int laneid_shfl_2 = lane_id >> 2;          // 0 ~ 15: row slot of this lane
    const int v_lane_headdim_n_idx = lane_id & 3;    // column slot of this lane within the row
    const int base = laneid_shfl_2 & 0xc;            // first row slot of this lane's group of 4 rows
    const int tail = laneid_shfl_2 & 0x3;            // position inside that group
    // Rows inside each group of 4 are permuted 0, 1, 2, 3 -> 0, 2, 1, 3.
    const int v_lane_seq_k_idx = base + (tail & 1) * 2 + (tail >> 1);
    auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
    // 2 elements per lane, 16 lanes per 32-element row => 4 rows per request.
    constexpr int READ_ONCE_LINES = 4;
    constexpr int READ_ELEMENT_COUNT = 2;
    const int laneid_shfl_4 = lane_id >> 4;          // 0 ~ 3
    const int laneid_shfl_5 = lane_id >> 5;          // 0 ~ 1
    const int v_lane_headdim_n_idx = lane_id & 15;   // column slot of this lane within the row
    // Row permutation 0, 1, 2, 3 -> 0, 2, 1, 3.
    const int v_lane_seq_k_idx = (laneid_shfl_4 & 1) * 2 + laneid_shfl_5;
    auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
#endif
    constexpr int READ_ONCE_COUNT = READ_ONCE_LINES * 32;                 // elements per request
    constexpr int V_LOAD_REQUESTS = kBlockN * WARP_K / READ_ONCE_COUNT;   // requests per tile
    // Number of tiles to put in flight: see the table in the header comment.
    constexpr int PREFETCH_TILES = (STAGES > 2) ? 3 : ((STAGES > 1) ? 1 : 0);

    if constexpr (PREFETCH_TILES == 0) {
        // Nothing to load; keep the scheduling barrier the single-stage path always had.
        __builtin_amdgcn_sched_barrier(0);
    }
    #pragma unroll
    for (int tile = 0; tile < PREFETCH_TILES; ++tile) {
        // Wave-uniform offset: first row of this wave plus the column offset of the tile.
        const int s_offset = (WARP_ID * WARP_K * vcache_seqlen_stride + tile * kBlockN) / 2;
        // Destination stage slot inside this wave's LDS region.
        Element* stage_lds = v_lds + WARP_ID * STAGES * WARP_K * kBlockN + (stage_id + tile) * WARP_K * kBlockN;
        for (int load = 0; load < V_LOAD_REQUESTS; ++load) {
            // Row handled by this lane in this request, clamped to the last allowed row.
            const int row = min(v_lane_seq_k_idx + load * READ_ONCE_LINES, max_seq_kv_offset - 1);
            const int v_offset = (v_lane_headdim_n_idx * READ_ELEMENT_COUNT + row * vcache_seqlen_stride) / 2;
            const int v_lds_offset = (load * READ_ONCE_COUNT) / 2;
            BUFFER_LOAD_FUNC(stage_lds, gV, v_lds_offset, s_offset, v_offset);
        }
        // Keep the loads of one tile grouped and ordered before the next tile's loads.
        __builtin_amdgcn_sched_barrier(0);
    }
}
\ No newline at end of file
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
/**
 * Issue the global -> LDS buffer loads for the last kBlockN-wide tile of V
 * (tile index kHeadDim / kBlockN - 1) into stage slot stage_id of this wave's LDS region.
 *
 * Loads are issued only when STAGES > 1 and are never waited for here; a scheduling barrier
 * is emitted in every configuration.
 *
 * @param v_addr                 descriptor handed to the buffer-load helper for V.
 * @param v_lds                  LDS base; this wave writes at
 *                               v_lds + warp_id * STAGES * WARP_K * kBlockN + stage_id * WARP_K * kBlockN.
 * @param warp_id                wave index; also selects the first V row (warp_id * WARP_K).
 * @param kvcache_seqlen_stride  multiplied by the clamped row index to form the row offset.
 * @param max_seq_kv_offset      row indices are clamped to max_seq_kv_offset - 1. With the
 *                               default -1 the clamp value is -2, so callers presumably always
 *                               pass a real bound -- TODO confirm.
 *
 * The "/ 2" on the offsets presumably converts 16-bit element counts into dword units --
 * TODO confirm against the load helper.
 */
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void kvcache_prefetch_v_to_lds_tile16x32(
    vec4_uint v_addr,
    Element* v_lds,
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_kv_offset=-1) {
    // Lane index inside the 64-lane wave.
    const int lane = threadIdx.x & 63;
#if defined(USE_BUFFER_LOAD_DWORDX4)
    // 8 elements per lane, 4 lanes per 32-element row => 16 rows per request.
    constexpr int kRowsPerRequest = 16;
    constexpr int kElemsPerLane = 8;
    const int lane_col = lane & 3;
    const int lane_row = lane >> 2;
    auto issue_load = &inline_buffer_load_dwordx4_lds<Element, 2>;
#else
    // 2 elements per lane, 16 lanes per 32-element row => 4 rows per request
    // (lanes 0-15 row 0, 16-31 row 1, 32-47 row 2, 48-63 row 3).
    constexpr int kRowsPerRequest = 4;
    constexpr int kElemsPerLane = 2;
    const int lane_col = lane & 15;
    const int lane_row = lane >> 4;
    auto issue_load = &inline_buffer_load_dword_lds<Element, 2>;
#endif
    constexpr int kElemsPerRequest = kRowsPerRequest * 32;
    constexpr int kRequests = (kBlockN * WARP_K) / kElemsPerRequest;

    if constexpr (STAGES > 1) {
        // Column offset of the last tile; identical for every lane of the wave.
        const int last_tile = (kHeadDim / kBlockN) - 1;
        const int tile_offset = (last_tile * kBlockN) / 2;
        Element* dst = v_lds + warp_id * STAGES * WARP_K * kBlockN + stage_id * WARP_K * kBlockN;
        #pragma unroll
        for (int req = 0; req < kRequests; ++req) {
            // Absolute row for this lane and request, clamped to the last allowed row.
            const int row = min(lane_row + req * kRowsPerRequest + warp_id * WARP_K, max_seq_kv_offset - 1);
            const int lane_offset = (lane_col * kElemsPerLane + row * kvcache_seqlen_stride) / 2;
            const int lds_offset = (req * kElemsPerRequest) / 2;
            issue_load(dst, v_addr, lds_offset, tile_offset, lane_offset);
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "kvcache_qk_gemm_prefetch_v_3stage.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int WARP_ID,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset=0) {
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
union_vec4_f16x2<Element> k_reg[STAGES * (WARP_N * kBlockK) / (32 * 32) * 2];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
int k_warp_n_id = WARP_ID & (WARP_N / WARP_N - 1);
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
constexpr int k_lds_load_num = WARP_N * kBlockK / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
constexpr int ELEMENT_BYTES = sizeof(Element);
int stage_id = 0;
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg);
} else {
inline_vgpr4_init_zero_1x4x4(s_reg);
}
__builtin_amdgcn_sched_barrier(0);
#else
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
#endif
constexpr int K_LOOP_START = (STAGES == 2) ? 1: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for(int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK + WARP_ID * WARP_N * kcache_seqlen_stride;
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 34) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0; load < K_LOAD_REQUESTS; ++load) {
int __load = (load + WARP_ID) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2;
int k_warp_buffer_load_n_id = __load & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / ELEMENT_BYTES;
int v_offset = min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1) * kcache_seqlen_stride / ELEMENT_BYTES + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
// 在 wait 之前提前计算这部分偏移量
if constexpr (STAGES == 2) stage_id ^=1;
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// 保留第 1 阶段最后一波数据实际的 stage_id
int last_stage_id = stage_id ^ 1;
// 等待第 1 阶段最后一波数据返回做计算
if constexpr (STAGES == 2) {
// stage_id ^= 1;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2 * 2];
if constexpr (true) {
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + last_stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 2)) {
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = ((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx< (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = ((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = last_stage_id * (WARP_N * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
kvcache_prefetch_v_to_lds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(gV, v_lds, WARP_ID, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "kvcache_qk_gemm_utils.h"
/*
 * kvcache_qk_gemm_prefetch_v_3stage
 *
 * Accumulates Q x K MMAC products into s_reg in four hand-unrolled steps
 * (k_loop = 0, 1, 2, 3), using K stages 0, 1, 2 and then stage 0 again.
 * Each step:
 *   1. precomputes the LDS addresses of this warp's K tile,
 *   2. waits on outstanding buffer loads via buffer_load_lds_dwordx1_wait_nosync<N>,
 *   3. issues LDS reads from k_lds into the local k_reg array,
 *   4. runs the mmac calls for min_tile_n == 0, waits for the remaining LDS reads,
 *      then runs the mmac calls for min_tile_n == 1.
 * During the k_loop == 1 step it additionally issues K_LOAD_REQUESTS buffer loads
 * that refill K stage 0 from head-dim offset (k_loop + 2) * kBlockK; that data is
 * consumed by the k_loop == 3 step. After the last step it calls __syncthreads()
 * and, when STAGES > 1, kvcache_prefetch_v_to_lds.
 *
 * NOTE(review): this function does not issue the buffer loads consumed by the
 * k_loop == 0, 1, 2 steps; presumably the caller issues them before the call - confirm.
 * NOTE(review): the four fixed steps and the use of stage_id == 2 appear to assume
 * kHeadDim / kBlockK == 4 and STAGES == 3; only kBlockK == 32 is enforced by a static_assert.
 *
 * Parameters:
 *   gQ                    not referenced in this function.
 *   gK                    passed to inline_buffer_load_dword_lds as the source of the K refill
 *                         (presumably a buffer resource descriptor - confirm).
 *   gV                    forwarded to kvcache_prefetch_v_to_lds.
 *   q_lds                 not referenced in this function.
 *   k_lds                 LDS base of the K tiles; written by the refill and read into k_reg.
 *   v_lds                 forwarded to kvcache_prefetch_v_to_lds.
 *   q_reg                 Q operand tiles, read only; indexed by k_loop, head_dim_idx, m_idx, min_tile_m.
 *   s_reg                 accumulators; zero-initialised here, then updated by every mmac call.
 *   WARP_ID               selects this warp's LDS region and its row offset in the K refill.
 *   kcache_seqlen_stride  row stride used to build the global offsets of the K refill.
 *   vcache_seqlen_stride  forwarded to kvcache_prefetch_v_to_lds.
 *   max_seq_k_offset      rows of the K refill are clamped to max_seq_k_offset - 1; also forwarded
 *                         to kvcache_prefetch_v_to_lds.
 *                         NOTE(review): with the default of 0 the clamp bound is -1, so callers
 *                         are presumably expected to always pass a real value - confirm.
 */
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_3stage(
vec4_uint gQ,
vec4_uint gK,
vec4_uint gV,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int WARP_ID,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset = 0) {
// NOTE(review): the message string is part of the asserted expression ("cond and msg"),
// so the assertion only depends on kBlockK == 32; the string is not a separate message argument.
static_assert (kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
// Per-thread K operand storage, one group of tiles per stage; indexed below with stage_id 0, 1 and 2.
union_vec4_f16x2<Element> k_reg[STAGES * WARP_N * kBlockK / (32 * 32) * 2];
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
// Row (within a group of 4) and column handled by this lane in the K refill below.
int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int qk_lane_head_dim_idx = laneid_and_15;
// WARP_N / WARP_N - 1 is 0, so k_warp_n_id is always 0 here.
int k_warp_n_id = WARP_ID & (WARP_N / WARP_N - 1);
// Lane-dependent part of the LDS read address; it is multiplied by 4 when the addresses are built below.
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 17) + (lane_id & 1) * 16 + (laneid_and_15 >> 1) * 65 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32);
// Number of buffer-load requests issued per warp for one K stage.
constexpr int k_lds_load_num = WARP_N * kBlockK / (4 * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
constexpr int ELEMENT_BYTES = sizeof(Element);
// After the load instructions have been issued, do some initialization work first
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// NOTE(review): these helpers presumably zero s_reg, mirroring the #else branch - confirm.
if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg);
} else {
inline_vgpr4_init_zero_1x4x4(s_reg);
}
__builtin_amdgcn_sched_barrier(0);
#else
// Generic path: clear the accumulator slots used by the mmac calls (slot = min_tile_n * 2 + min_tile_m).
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
#endif
int k_loop = 0;
int stage_id = 0;
// ---- Step k_loop = 0: consume K stage 0 ----
// for(int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop++)
{
stage_id = 0;
// Precompute these offsets ahead of the wait
// NOTE(review): the array is indexed only by (i, j); if WARP_N / 32 > 1 the values for
// earlier n_idx are overwritten and only the last n_idx survives - confirm WARP_N == 32 is assumed.
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
// Address = LDS base pointer value + 4 * (offset counted in 4-byte units), truncated to int.
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Presumably waits until at most 2 * K_LOAD_REQUESTS buffer loads are still outstanding,
// i.e. stage 0 has landed while stages 1 and 2 may still be in flight - confirm helper semantics.
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
// Issue the LDS read for K tile i, half j, without waiting for its completion.
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
// With USE_DS_OVERLAP_MMAC only a partial wait is done here (up to 2 LDS reads may still be
// pending) so the min_tile_n == 0 mmac calls can overlap the remaining reads; otherwise wait for all.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
// First half: K tile min_tile_n == 0, accumulating into s_reg slots 0 .. M_MMAC_COUNT - 1.
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
// The remaining LDS reads must have completed before the second K tile is used.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
// Second half: K tile min_tile_n == 1, accumulating into s_reg slots 2 .. 2 + M_MMAC_COUNT - 1.
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// ---- Step k_loop = 1: refill K stage 0 for k_loop = 3, then consume K stage 1 ----
k_loop = 1;
{
{
// Stage 0 was fully read into k_reg by the previous step, so its LDS region is reused
// as the destination of the loads for head-dim offset (k_loop + 2) * kBlockK.
stage_id = 0;
int k_block_buffer_load_global_offset = (k_loop + 2) * kBlockK + WARP_ID * WARP_N * kcache_seqlen_stride;
// Uses a 32 * 34 stride (the read side uses 32 * 17); the sum is divided by ELEMENT_BYTES below.
// NOTE(review): the two strides agree only when ELEMENT_BYTES == 2 - confirm the unit expected by the helper.
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 34) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0; load < K_LOAD_REQUESTS; ++load) {
// Rotate the request order by WARP_ID so that warps start at different requests.
int __load = (load + WARP_ID) % K_LOAD_REQUESTS;
int padding = (__load & 7) * 2;
int k_warp_buffer_load_n_id = __load & (WARP_N / 4 - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
int s_offset = k_block_buffer_load_global_offset / ELEMENT_BYTES;
// Row index (4 rows per request, one per qk_lane_m_idx) is clamped to max_seq_k_offset - 1.
int v_offset = min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1) * kcache_seqlen_stride / ELEMENT_BYTES + qk_lane_head_dim_idx;
int lds_offset = (k_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
inline_buffer_load_dword_lds(k_lds, gK, lds_offset, s_offset, v_offset);
}
}
stage_id = 1;
// Precompute these offsets ahead of the wait
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Presumably leaves stage 2 and the stage-0 refill issued above in flight while stage 1 is complete - confirm.
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
// Partial LDS wait (overlap mode) or full wait, same scheme as in the k_loop = 0 step.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
// First half: K tile min_tile_n == 0 of stage 1.
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
// Second half: K tile min_tile_n == 1 of stage 1.
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// ---- Step k_loop = 2: consume K stage 2 ----
k_loop = 2;
stage_id = 2;
{
// Precompute these offsets ahead of the wait
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Presumably only the stage-0 refill (K_LOAD_REQUESTS loads) may remain outstanding after this wait - confirm.
buffer_load_lds_dwordx1_wait_nosync<K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
// Partial LDS wait (overlap mode) or full wait, same scheme as in the k_loop = 0 step.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
// First half: K tile min_tile_n == 0 of stage 2.
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
// Second half: K tile min_tile_n == 1 of stage 2.
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// ---- Step k_loop = 3: consume K stage 0 again (the refill issued during the k_loop = 1 step) ----
k_loop = 3;
stage_id = 0;
{
// Precompute these offsets ahead of the wait
int precompute_k_lds_offset[2 * 2];
int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 17) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for(int i = 0; i < 2; ++i) {
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
for(int j = 0; j < 2; ++j) {
precompute_k_lds_offset[i * 2 + j] = reinterpret_cast<size_t>(k_lds_v2fp16) + (k_lds_stage_offset + head_dim_idx * (WARP_N * 17) + n_idx * (32 * 17) + j * 4 + i * 32 + k_ds_read_offset) * 4;
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Presumably waits for every outstanding buffer load, including the stage-0 refill - confirm.
buffer_load_lds_dwordx1_wait_nosync<0>();
__builtin_amdgcn_sched_barrier(0);
if constexpr (true) {
#pragma unroll
for(int i = 0; i < 2; ++i) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int j = 0; j < 2; ++j) {
int lds_offset = precompute_k_lds_offset[i * 2 + j];
inline_ds_read2_b32_no_wait_bytes(lds_offset, k_reg[stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].u64[j], 2);
}
}
}
}
}
// Partial LDS wait (overlap mode) or full wait, same scheme as in the k_loop = 0 step.
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(2)");
#else
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
{
// First half: K tile min_tile_n == 0 of the refilled stage 0.
int min_tile_n = 0;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
#ifdef USE_DS_OVERLAP_MMAC
asm volatile("s_waitcnt lgkmcnt(0)");
#endif
__builtin_amdgcn_sched_barrier(0);
{
// Second half: K tile min_tile_n == 1 of the refilled stage 0.
int min_tile_n = 1;
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = stage_id * WARP_N * kBlockK / (32 * 32) * 2 + (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
q_reg[q_tile_id].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// The required V load instructions can be issued now;
if constexpr (STAGES > 1) {
kvcache_prefetch_v_to_lds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(gV, v_lds, WARP_ID, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "kvcache_qk_gemm_utils_tile16x32.h"
#define USE_DS_OVERLAP_MMAC
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_tile16x32(
vec4_uint q_addr,
vec4_uint k_addr,
vec4_uint v_addr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int kcache_seqlen_stride,
int vcache_seqlen_stride,
int max_seq_k_offset=-1) {
static_assert(kBlockK == 32 and "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];
// 预先计算一些表达式
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 16;
#else // zd
int qk_lane_m_idx = laneid_shfl_4;
int qk_lane_head_dim_idx = laneid_and_15;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 4;
#endif
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (READ_ONCE_LINES * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
int k_warp_n_id = (warp_id & (WARP_N / WARP_N - 1));
// 0,0,32,32,0,0,32,32 | 0,0,32,32,0,0,32,32 | 16,16,48,48,16,16,48,48 | 16,16,48,48,16,16,48,48
// (lane_id & 1) * 16: in seqlen direction, [0,1,0,1,2,3,2,3], odd threads need skip 32 Halfs, 16 dword
// (laneid_and_15 >> 1) * 64: threads 0,1 occupy 4 lines, 4x32 Halfs, 64 dword.... 2,3 and 4,5 and 6,7 is the same
// laneid_and_15 >> 1, padding
// (laneid_shfl_4 & 1) * 8: threads 0,32 is even times of 16, thus 0,32; threads 16,48 is odd times of 16, thus 0,32,16,48; 0->16 need skip 16 Halfs, 8 dword
// (lane_id / 32): 0,0,32,32,0,0,32,32, 0->32, skip 2 Halfs, 1 dword
int k_ds_read_offset = k_warp_n_id * (WARP_N / 32) * (32 * 16) + laneid_and_15 * 16 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32) * 4;
// 初始化 s
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
}
}
}
int stage_id = 0;
constexpr int K_LOOP_START = (STAGES == 2) ? 2: 0;
if constexpr (STAGES == 2) stage_id ^= 1;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop += 2) {
if constexpr (true) {
int k_block_buffer_load_global_offset = k_loop * kBlockK/*offset in headdim direction*/;
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32));
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
if constexpr (true) {
int k_block_buffer_load_global_offset = (k_loop + 1) * kBlockK/*offset in headdim direction*/;
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + ((k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32));
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
// 在 wait 之前提前计算这部分偏移量
if constexpr (STAGES == 2) stage_id ^= 1;
int precompute_k_lds_offset[2];
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * WARP_N * 16 + n_idx * 32 * 16 + i * 16 * 16 + k_ds_read_offset;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<3 * K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 2: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
{
int precompute_k_lds_offset[2];
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * WARP_N * 16 + n_idx * 32 * 16 + i * 16 * 16 + k_ds_read_offset;
}
}
}
__builtin_amdgcn_sched_barrier(0);
#ifdef USE_PINGPANG_BUFFER
if constexpr (STAGES == 2) {
buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
} else if constexpr (STAGES == 1) {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
#else
buffer_load_lds_dwordx1_wait_nosync<0>();
#endif
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int k_loop_idx = (STAGES == 2) ? k_loop - 1: k_loop;
int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// 保留第 1 阶段最后一波数据实际的 stage_id
int last_stage_id = stage_id ^ 1;
// 等待第 1 阶段最后一波数据返回做计算
if constexpr (STAGES == 2) {
constexpr int k_loop = kHeadDim / kBlockK;
// 在 wait 之前提前计算好 lds load 的下标
int precompute_k_lds_offset[2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (warp_id * STAGES * 2 + last_stage_id * 2) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * (WARP_N * 16) + n_idx * (32 * 16) + i * 16 * 16 + k_ds_read_offset;
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN >= (WARP_N * 2)) {
buffer_load_lds_dwordx1_wait_nosync<1 * K_LOAD_REQUESTS>();
} else {
buffer_load_lds_dwordx1_wait_nosync<0>();
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - 2) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
{
int precompute_k_lds_offset[2];
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
int k_lds_stage_offset = (warp_id * STAGES * 2 + last_stage_id * 2 + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 16);
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
precompute_k_lds_offset[i] = k_lds_stage_offset + head_dim_idx * (WARP_N * 16) + n_idx * (32 * 16) + i * 16 * 16 + k_ds_read_offset;
}
}
}
// 等待最后一波数据的返回
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_dwordx1_wait_nosync<0 * K_LOAD_REQUESTS>();
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
for (int i = 0; i < 2; ++i) {
k_reg[(head_dim_idx * (WARP_N / 32) + n_idx) * 2 + i].f32 = *(vec4_fp32*)(k_lds_v2fp16 + precompute_k_lds_offset[i]);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int q_tile_id = (k_loop - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
vec4_Element<Element>{q_reg[q_tile_id].f16x2[2 * min_tile_k][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k][1],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][0],
q_reg[q_tile_id].f16x2[2 * min_tile_k + 1][1]},
vec4_Element<Element>{k_reg[k_tile_id].f16x2[2 * min_tile_k][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k][1],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][0],
k_reg[k_tile_id].f16x2[2 * min_tile_k + 1][1]},
s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
// need to reduce results on scores_max and prefetch V, and thus sync
__syncthreads();
// qk gemm 等待最后一次计算需要的数据之前, 可以先把需要的 V load 指令发下去;
if constexpr (STAGES > 1) {
kvcache_prefetch_v_to_lds_tile16x32<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32/*WARP_K*/, 0, WARP_NUM, Element, STAGES>(v_addr, v_lds, warp_id, vcache_seqlen_stride, max_seq_k_offset);
}
} // qk_gemm
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "kvcache_pv_gemm_utils.h"
/*
 * Stage the Q tile through LDS (q_lds) and read it into per-thread registers (q_reg),
 * one kBlockK-wide slice of the head dimension at a time.
 *
 * Flow, as written below:
 *   - STAGES > 1: the load for slice 0 is issued before the loop. Each loop iteration
 *     then issues the load for slice k_loop into one LDS stage, waits, and reads the
 *     previous slice (k_loop - 1) out of the other stage. The epilogue after the loop
 *     reads the final slice.
 *   - otherwise: each iteration loads slice k_loop and then reads that same slice.
 *
 * Parameters:
 *   gQ       - descriptor handed unchanged to builtin_buffer_load_dword_lds.
 *   q_lds    - LDS staging area written by the buffer loads and read back below.
 *   q_reg    - destination registers, indexed by (slice, head_dim_idx, m_idx, i).
 *   WARP_ID  - first load slot handled by this warp; further slots are WARP_NUM apart.
 *   kvcache_prefetch_q_to_vgpr
 *            - despite sharing the function's name, this int is the per-row multiplier
 *              used when forming the load offset (presumably the Q sequence stride;
 *              confirm against callers).
 *   max_seq_q_offset
 *            - rows whose index is >= this value issue no load, so this function does
 *              not write their LDS slots. With the default of 0 the guard is never
 *              satisfied and no load is issued at all.
 */
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int REUSE_KV_TIMES, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_prefetch_q_to_vgpr(
vec4_uint gQ,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
int WARP_ID,
int kvcache_prefetch_q_to_vgpr,
int max_seq_q_offset=0) {
constexpr bool Is_GQA = M_MMAC_COUNT > 1;
// NOTE(review): in the GQA arm, `/` binds tighter than `<<`, so the expression parses as
// (((REUSE_KV_TIMES + 1) >> 1) << (2 / WARP_NUM)), not ((... << 2) / WARP_NUM).
// The two forms agree for WARP_NUM in {1, 2, 4} but differ for larger values - confirm intent.
constexpr int Q_LOAD_REQUESTS = (REUSE_KV_TIMES == 0)
? kBlockM * kBlockK / (4 * 32) / WARP_NUM
: Is_GQA ? ((REUSE_KV_TIMES + 1) >> 1) << 2 / WARP_NUM: 1/*MHA only need the first token*/;
// Number of `i` iterations performed per (head_dim_idx, m_idx) when reading LDS back.
constexpr int SEQUENCE_READ = Is_GQA ? 2: 1;
constexpr int ELEMENT_BYTES = sizeof(Element);
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = ((lane_id >> 4) & 1) * 2 + ((lane_id >> 4) >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
int q_lane_head_dim_idx = lane_id & 15;
int stage_id = 0;
// Prologue (multi-stage only): issue the load for slice 0 into stage 0 without waiting.
if constexpr (STAGES > 1) {
int k_loop = 0;
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
// Per-slot padding added to the LDS destination offset.
int padding = (warp_loop & 7) * 2;
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int s_offset = q_block_buffer_load_global_offset / ELEMENT_BYTES;
// Row index of this lane within the Q tile; reused below as the load offset.
int v_offset = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (v_offset < max_seq_q_offset) {
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
v_offset = (v_offset * kvcache_prefetch_q_to_vgpr) / ELEMENT_BYTES + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, s_offset, v_offset);
}
}
}
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
for(int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// Issue the load for slice k_loop into the stage selected by stage_id.
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 34);
for(int load = 0, warp_loop = WARP_ID; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7) * 2;
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / 4 - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 34) + (q_warp_buffer_load_m_id & 7) * (4 * 32);
int s_offset = q_block_buffer_load_global_offset / ELEMENT_BYTES;
int v_offset = q_warp_buffer_load_m_id * 4 + q_lane_m_idx;
if (v_offset < max_seq_q_offset) {
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
v_offset = (v_offset * kvcache_prefetch_q_to_vgpr) / ELEMENT_BYTES + q_lane_head_dim_idx;
builtin_buffer_load_dword_lds(q_lds, gQ, lds_offset, s_offset, v_offset);
}
}
// Wait (s_waitcnt with immediate 0) and barrier before any thread reads LDS.
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// Multi-stage: flip to the other stage, which holds the previous slice.
if constexpr (STAGES > 1) stage_id ^= 1;
// Same stage stride as above, now expressed in vec2_Element units (32 * 17 instead of 32 * 34).
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for(int j = 0; j < 4; ++j) {
// NOTE(review): this offset has no m_idx (or WARP_ID) term, so every m_idx reads the
// same LDS locations - presumably fine because WARP_M / 32 == 1 in practice; confirm.
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + j * 2 + i * 32 + (lane_id & 1) * 16 + ((lane_id & 15) >> 1) * 64 + /*padding*/ ((lane_id & 15) >> 1) + ((lane_id / 16) & 1) * 8 + (lane_id / 32);
// Multi-stage reads lag the loads by one slice.
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f16x2[j]);
}
}
}
}
// Barrier before the next iteration's loads overwrite the stage that was just read.
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
// Epilogue (multi-stage only): read the final slice, whose load was issued in the last iteration.
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 17);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for(int head_dim_idx = 0; head_dim_idx < (kBlockK / 32); ++head_dim_idx) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#pragma unroll
for(int i = 0; i < SEQUENCE_READ; ++i) {
#pragma unroll
for(int j = 0; j < 4; ++j) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 17 + j * 2 + i * 32 + (lane_id & 1) * 16 + ((lane_id & 15) >> 1) * 64 + /*padding*/ ((lane_id & 15) >> 1) + ((lane_id / 16) & 1) * 8 + (lane_id / 32);
inline_ds_read_b32_wait(q_lds_v2fp16, lds_offset, q_reg[(kHeadDim / kBlockK - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f16x2[j]);
}
}
}
}
}
// Final wait and barrier before returning to the caller.
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
}
/*
 * Issue the initial K-tile loads from the K buffer into LDS (k_lds) ahead of the QK GEMM.
 *
 * Up to three kBlockK-wide slices of the head dimension are requested, each into its own
 * LDS stage:
 *   - slice 0 -> stage 0, when STAGES > 1
 *   - slice 1 -> stage 1, when STAGES > 2
 *   - slice 2 -> stage 2, when STAGES > 2
 * The function only issues the loads; it does not wait for them or synchronise.
 *
 * Parameters:
 *   k_ptr                 - descriptor handed unchanged to the buffer-load helper.
 *   k_lds                 - LDS destination.
 *   WARP_ID               - selects this warp's WARP_N rows of K and its LDS region; it also
 *                           rotates the order in which the load requests are issued.
 *   kvcache_seqlen_stride - multiplier applied to a row index to form the load offset.
 *   max_seq_k_offset      - row indices are clamped to (max_seq_k_offset - 1). With the
 *                           default of 0 every row index clamps to -1, so callers are
 *                           presumably expected to pass a real bound (confirm).
 *
 * The template parameters kHeadDim, kBlockM and WARP_M are not referenced in the body.
 */
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void kvcache_prefetch_k_to_lds(
    vec4_uint k_ptr,
    Element* k_lds,
    int WARP_ID,
    int kvcache_seqlen_stride,
    int max_seq_k_offset=0) {
    int lane_id = threadIdx.x & 63; // lane id, 0-63
    int laneid_shfl_4 = lane_id >> 4;
    int laneid_and_15 = lane_id & 15;
    int qk_lane_m_idx = (laneid_shfl_4 & 1) * 2 + (laneid_shfl_4 >> 1); // (0, 1, 2, 3) --> (0, 2, 1, 3)
    int qk_lane_head_dim_idx = laneid_and_15;
    constexpr int k_lds_load_num = WARP_N * kBlockK / (4 * 32);
    constexpr int K_LOAD_REQUESTS = k_lds_load_num;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
    int stage_id = 0;
    int k_loop = 0;

    // Slice 0 -> stage 0.
    if constexpr (STAGES > 1) {
        int k_block_buffer_load_global_offset = k_loop * kBlockK + WARP_ID * WARP_N * kvcache_seqlen_stride;
        int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 34) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
        for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
            // Each warp starts at a different request index (rotation by WARP_ID).
            int rotated_load = (load + WARP_ID) % K_LOAD_REQUESTS;
            int padding = (rotated_load & 7) * 2;
            int k_warp_buffer_load_n_id = rotated_load & (WARP_N / 4 - 1);
            int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
            int s_offset = k_block_buffer_load_global_offset / ELEMENT_BYTES;
            // Clamp the row index so lanes past the end re-read the last permitted row.
            int v_offset = (min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kvcache_seqlen_stride / ELEMENT_BYTES + qk_lane_head_dim_idx;
            int lds_offset = (k_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
            BUFFER_LOAD_FUNC(k_lds, k_ptr, lds_offset, s_offset, v_offset);
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Slice 1 -> stage 1.
    if constexpr (STAGES > 2) {
        stage_id = 1;
        int k_block_buffer_load_global_offset = (k_loop + 1) * kBlockK + WARP_ID * WARP_N * kvcache_seqlen_stride;
        int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 34) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
        for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
            int rotated_load = (load + WARP_ID) % K_LOAD_REQUESTS;
            int padding = (rotated_load & 7) * 2;
            int k_warp_buffer_load_n_id = rotated_load & (WARP_N / 4 - 1);
            int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
            // Scaled by ELEMENT_BYTES like every other offset here (was a literal 2,
            // which only matched the other stages for 2-byte elements).
            int s_offset = k_block_buffer_load_global_offset / ELEMENT_BYTES;
            int v_offset = (min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kvcache_seqlen_stride / ELEMENT_BYTES + qk_lane_head_dim_idx;
            int lds_offset = (k_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
            BUFFER_LOAD_FUNC(k_lds, k_ptr, lds_offset, s_offset, v_offset);
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Slice 2 -> stage 2.
    if constexpr (STAGES > 2) {
        stage_id = 2;
        int k_block_buffer_load_global_offset = (k_loop + 2) * kBlockK + WARP_ID * WARP_N * kvcache_seqlen_stride;
        int k_lds_stage_offset = WARP_ID * (WARP_N / 32) * (kBlockK / 32) * (32 * 34) + stage_id * WARP_NUM * (WARP_N / 32) * (kBlockK / 32) * (32 * 34);
        for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
            int rotated_load = (load + WARP_ID) % K_LOAD_REQUESTS;
            int padding = (rotated_load & 7) * 2;
            int k_warp_buffer_load_n_id = rotated_load & (WARP_N / 4 - 1);
            int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 34) + (k_warp_buffer_load_n_id & 7) * (4 * 32);
            int s_offset = k_block_buffer_load_global_offset / ELEMENT_BYTES;
            int v_offset = (min(k_warp_buffer_load_n_id * 4 + qk_lane_m_idx, max_seq_k_offset - 1)) * kvcache_seqlen_stride / ELEMENT_BYTES + qk_lane_head_dim_idx;
            int lds_offset = (k_warp_buffer_load_lds_offset + padding) / ELEMENT_BYTES;
            BUFFER_LOAD_FUNC(k_lds, k_ptr, lds_offset, s_offset, v_offset);
        }
    }
}
\ No newline at end of file
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "kvcache_pv_gemm_utils_tile16x32.h"
/*
 * 16x32-tile variant of the Q prefetch: stage the Q tile through LDS (q_lds) and read it
 * into per-thread registers (q_reg), one kBlockK-wide slice of the head dimension at a time.
 *
 * Flow, as written below:
 *   - STAGES > 1: the load for slice 0 is issued before the loop. Each loop iteration
 *     then issues the load for slice k_loop into one LDS stage, waits, and reads the
 *     previous slice (k_loop - 1) out of the other stage. The epilogue after the loop
 *     reads the final slice.
 *   - otherwise: each iteration loads slice k_loop and then reads that same slice.
 *
 * Parameters:
 *   q_addr              - descriptor handed unchanged to the buffer-load helper.
 *   q_lds               - LDS staging area.
 *   q_reg               - destination registers; each entry is filled by one vec4_fp32 read.
 *   warp_id             - first load slot handled by this warp; further slots are WARP_NUM apart.
 *   query_seqlen_stride - multiplier applied to the row index when forming the load offset.
 *   max_seq_q_offset    - row indices are clamped to (max_seq_q_offset - 1). With the
 *                         default of -1 that clamp value is -2, so callers are presumably
 *                         expected to pass a real bound (confirm).
 *
 * Offsets are divided by a literal 2 throughout rather than by sizeof(Element);
 * this presumably assumes 2-byte elements (confirm).
 */
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_tile16x32(
vec4_uint q_addr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * ((WARP_M * kBlockK) / (32 * 32)) * 2],
int warp_id,
int query_seqlen_stride,
int max_seq_q_offset=-1) {
constexpr int Q_LOAD_REQUESTS = (kBlockM * kBlockK >> 1/*16x32 tile*/) * M_MMAC_COUNT / (4 * 32 * WARP_NUM);
// Number of `i` iterations performed per (head_dim_idx, m_idx) when reading LDS back.
constexpr int SEQUENCE_READ = M_MMAC_COUNT;
// Rows covered by one load slot (the lane row index below spans 0..3).
constexpr int READ_ONCE_LINES = 4;
auto BUFFER_LOAD_FUNC = &builtin_buffer_load_dword_lds<Element, float, 1>; // buffer_load_dwordx4 can also be applied if necessary
int lane_id = threadIdx.x & 63; // lane id, 0-63
int q_lane_m_idx = lane_id >> 4;
int q_lane_head_dim_idx = lane_id & 15;
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
// Lane-dependent part of the LDS read offset, shared by every read below.
int q_ds_read_offset = laneid_and_15 * 16 + (laneid_shfl_4 & 1) * 8 + (lane_id / 32) * 4;
int stage_id = 0;
// Prologue (multi-stage only): issue the load for slice 0 into stage 0 without waiting.
if constexpr (STAGES > 1) {
int k_loop = 0;
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + (q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32);
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
// Clamp the row index instead of skipping the load, so every lane issues a request.
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
}
if constexpr (STAGES > 1) stage_id ^= 1;
constexpr int K_LOOP_START = (STAGES > 1) ? 1: 0;
for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); ++k_loop) {
// Issue the load for slice k_loop into the stage selected by stage_id.
int q_block_buffer_load_global_offset = k_loop * kBlockK;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 32);
for (int load = 0, warp_loop = warp_id; load < Q_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int q_warp_buffer_load_m_id = warp_loop & (kBlockM / READ_ONCE_LINES - 1);
int q_warp_buffer_load_lds_offset = q_lds_stage_offset + ((q_warp_buffer_load_m_id >> 3) * (32 * 32) + (q_warp_buffer_load_m_id & 7) * (READ_ONCE_LINES * 32));
int offset_s = q_block_buffer_load_global_offset / 2;
int offset_v = q_warp_buffer_load_m_id * READ_ONCE_LINES + q_lane_m_idx;
int lds_offset = q_warp_buffer_load_lds_offset / 2;
offset_v = (min(offset_v, max_seq_q_offset - 1) * query_seqlen_stride) / 2 + q_lane_head_dim_idx;
BUFFER_LOAD_FUNC(q_lds, q_addr, lds_offset, offset_s, offset_v);
}
// Wait (s_waitcnt with immediate 0) and barrier before any thread reads LDS.
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// Multi-stage: flip to the other stage, which holds the previous slice.
if constexpr (STAGES > 1) stage_id ^= 1;
// Same stage stride as above, now expressed in vec2_Element units (32 * 16 instead of 32 * 32).
q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
// NOTE(review): this offset has no m_idx (or warp_id) term, so every m_idx reads the
// same LDS locations - presumably fine because WARP_M / 32 == 1 in practice; confirm.
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset;
// Multi-stage reads lag the loads by one slice.
int k_loop_idx = (STAGES > 1) ? k_loop - 1: k_loop;
q_reg[k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f32 = *(vec4_fp32*)(q_lds_v2fp16 + lds_offset);
}
}
}
// Barrier before the next iteration's loads overwrite the stage that was just read.
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
}
// Epilogue (multi-stage only): read the final slice, whose load was issued in the last iteration.
if constexpr (STAGES > 1) {
__builtin_amdgcn_s_waitcnt(0);
stage_id ^= 1;
int q_lds_stage_offset = stage_id * (kBlockM / 32) * (kBlockK / 32) * (32 * 16);
vec2_Element<Element> *q_lds_v2fp16 = (vec2_Element<Element> *)(q_lds);
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx) {
#pragma unroll
for (int i = 0; i < SEQUENCE_READ; ++i) {
int lds_offset = q_lds_stage_offset + head_dim_idx * kBlockM * 16 + i * 16 * 16 + q_ds_read_offset;
q_reg[((kHeadDim / kBlockK) - 1) * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + i].f32 = *(vec4_fp32*)(q_lds_v2fp16 + lds_offset);
}
}
}
}
// Final wait and barrier before returning to the caller.
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
}
/*
 * 16x32-tile variant of the K prefetch: issue the loads for the first two kBlockK-wide
 * slices of this warp's K rows into two adjacent LDS slots. Both loads are guarded by
 * STAGES > 1; nothing is issued otherwise. The function does not wait for the loads.
 *
 * Parameters:
 *   k_addr                - descriptor handed unchanged to the buffer-load helper.
 *   k_lds                 - LDS destination.
 *   warp_id               - selects this warp's WARP_N rows of K and its LDS slots.
 *   kvcache_seqlen_stride - multiplier applied to the row index when forming the load offset.
 *   max_seq_k_offset      - row indices are clamped to (max_seq_k_offset - 1). With the
 *                           default of -1 that clamp value is -2, so callers are presumably
 *                           expected to pass a real bound (confirm).
 *
 * Offsets are divided by a literal 2 rather than by sizeof(Element);
 * this presumably assumes 2-byte elements (confirm).
 */
template<int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void kvcache_prefetch_k_to_lds_tile16x32(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int kvcache_seqlen_stride,
int max_seq_k_offset=-1) {
// Precompute some expressions
int lane_id = threadIdx.x & 63; // lane id, 0-63
int laneid_shfl_4 = lane_id >> 4;
int laneid_and_15 = lane_id & 15;
// The target architecture selects the load width: the first branch uses the dwordx4 helper
// with 16 rows per request, the fallback uses the dword helper with 4 rows per request.
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // >= bmz
int qk_lane_m_idx = lane_id >> 2;
int qk_lane_head_dim_idx = (lane_id & 3) << 2;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dwordx4_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 16;
#else // zd
int qk_lane_m_idx = laneid_shfl_4;
int qk_lane_head_dim_idx = laneid_and_15;
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element, 2>;
constexpr int READ_ONCE_LINES = 4;
#endif
constexpr int k_lds_load_num = (WARP_N * kBlockK) / (READ_ONCE_LINES * 32);
constexpr int K_LOAD_REQUESTS = k_lds_load_num;
int stage_id = 0;
int k_loop = 0;
// First slice (k_loop) -> LDS slot (warp_id * STAGES * 2 + stage_id * 2).
if constexpr (STAGES > 1) {
int k_block_buffer_load_global_offset = k_loop * kBlockK;
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32);
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
// Row index (including this warp's WARP_N-row offset) is clamped before scaling by the stride.
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
// Second slice (k_loop + 1) -> the next LDS slot (... + stage_id * 2 + 1).
if constexpr (STAGES > 1) {
int k_block_buffer_load_global_offset = (k_loop + 1) * kBlockK;
int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + 1) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
#pragma unroll
for (int load = 0; load < K_LOAD_REQUESTS; ++load) {
int k_warp_buffer_load_n_id = load & (WARP_N / READ_ONCE_LINES - 1);
int k_warp_buffer_load_lds_offset = k_lds_stage_offset + (k_warp_buffer_load_n_id >> 3) * (32 * 32) + (k_warp_buffer_load_n_id & 7) * (READ_ONCE_LINES * 32);
int lds_offset = k_warp_buffer_load_lds_offset / 2;
int offset_s = k_block_buffer_load_global_offset / 2;
int offset_v = min(k_warp_buffer_load_n_id * READ_ONCE_LINES + qk_lane_m_idx + warp_id * WARP_N, max_seq_k_offset - 1) * kvcache_seqlen_stride / 2 + qk_lane_head_dim_idx;
BUFFER_LOAD_FUNC(k_lds, k_addr, lds_offset, offset_s, offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
#pragma once
#include "philox.cuh"
#include "fwd/utils.h"
using namespace flash;
/*
 * Overwrite with -INFINITY every score entry whose key column is >= max_seqlen_k.
 *
 * The column of an entry is
 *   col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + v * 8
 * where tile_n walks the WARP_N / 32 column tiles, sub_n the two sub-tiles and v the
 * four vector components. For a masked column, all WARP_M / 32 row tiles and all
 * M_MMAC_COUNT row sub-tiles are overwritten.
 */
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k,
                                          const int col_idx_offset_ = 0) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_n = 0; tile_n < kColTiles; ++tile_n) {
        #pragma unroll
        for (int sub_n = 0; sub_n < 2; ++sub_n) {
            const int col_base = first_col + tile_n * 32 + sub_n;
            #pragma unroll
            for (int v = 0; v < 4; ++v) {
                const int col = col_base + v * 8;
                // In-range columns are left untouched.
                if (col < max_seqlen_k) {
                    continue;
                }
                #pragma unroll
                for (int tile_m = 0; tile_m < kRowTiles; ++tile_m) {
                    #pragma unroll
                    for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                        tensor[tile_m + tile_n * kRowTiles][sub_n * 2 + sub_m].f32[v] = -INFINITY;
                    }
                }
            }
        }
    }
}
/*
 * Apply dropout to the score tile: for each (row tile, column tile) pair, draw 16 random
 * bytes from flash::philox(seed, rowcol.u64, offset) and zero every in-range entry whose
 * byte is >= p_dropout_in_8bits_value.
 *
 * rowcol is taken by value. rowcol.u32.x advances by BLOCK_ROW_STRIDE per row tile and
 * rowcol.u32.y advances by one per column tile; y is not reset between row tiles.
 * Columns >= max_seqlen_k are skipped. dropout_debug_count is referenced only inside the
 * disabled (#if 0) debug section.
 */
template <typename DataType, int WARP_M, int WARP_N, int BLOCK_ROW_STRIDE, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_dropout(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k, const int col_idx_offset_,
                                             unsigned long long seed, unsigned long long offset, uint32_t p_dropout_in_8bits_value,
                                             union_vec2_uint rowcol, uint32_t* dropout_debug_count) {
    // static_assert(WARP_M == 32 and "For Dropout, only WARP_M=32 is supported yet!");
    const int lane = threadIdx.x & 63; // lane id, 0-63
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    // Four 32-bit words, viewed as sixteen 8-bit random values.
    union_vec4_uint random_bytes;
    // when WARP_M > 32, attention, block_row_idx is computed by BLOCK_M / 32 rather than BLOCK_M / WARP_M
    for (int tile_m = 0; tile_m < (WARP_M / 32); ++tile_m, rowcol.u32.x += BLOCK_ROW_STRIDE) {
        #pragma unroll
        for (uint32_t tile_n = 0; tile_n < (WARP_N / 32); ++tile_n, ++rowcol.u32.y) {
            // One philox draw covers the 16 entries of this tile held by the lane.
            random_bytes.u32 = flash::philox(seed, rowcol.u64, offset);
            int dropped = 0;
            #pragma unroll
            for (uint32_t sub_n = 0; sub_n < 2; ++sub_n) {
                const int col_base = first_col + tile_n * 32 + sub_n;
                #pragma unroll
                for (uint32_t v = 0; v < 4; ++v) {
                    const int col = col_base + v * 8;
                    if (col >= max_seqlen_k) {
                        continue;
                    }
                    #pragma unroll
                    for (uint32_t sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
                        const uint32_t byte_pos = (sub_n * 2 + sub_m) * 4 + v;
                        // uint8 -> u32, since hcu has no compare instructions with 8/16 bits
                        const uint32_t rand_value = random_bytes.u8[byte_pos] & 0xffffffff;
                        if (rand_value < p_dropout_in_8bits_value) {
                            continue;
                        }
                        tensor[tile_m + tile_n * (WARP_M / 32)][sub_n * 2 + sub_m].f32[v] = 0x0;
                        ++dropped;
                    }
                }
            }
#if 0
            atomicAdd(dropout_debug_count, dropped);
            if (threadIdx.x == 0) atomicAdd(dropout_debug_count + 1, 1);
#endif
        }
    }
}
/*
 * Causal mask: set to -INFINITY every entry whose column is greater than
 *   min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups).
 *
 * Row of an entry:    row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m
 * Column of an entry: col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + v * 8
 * Entries at or below the limit are left unchanged.
 */
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset_,
                                                 const int max_seqlen_q, const int ngroups) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15) * 2;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kRowTiles; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = first_row + tile_m * 32 + sub_m;
            // Last column this row keeps; the row / ngroups form is only for layout 1 (bshd).
            const int last_visible_col = std::min(max_seqlen_k, (row / ngroups) + max_seqlen_k - (max_seqlen_q / ngroups));
            #pragma unroll
            for (int tile_n = 0; tile_n < kColTiles; ++tile_n) {
                const int tile = tile_m + tile_n * kRowTiles;
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    const int frag = sub_n * 2 + sub_m;
                    const int col_base = first_col + tile_n * 32 + sub_n;
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        const int col = col_base + v * 8;
                        if (col > last_visible_col) {
                            tensor[tile][frag].f32[v] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
/*
 * Sliding-window (local) mask. For each row, with
 *   limit_left  = max(0, row + 1 + max_seqlen_k - max_seqlen_q - window_size_left)
 *   limit_right = min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q + window_size_right)
 * an entry is set to -INFINITY when its column is greater than limit_right, or - only
 * if HasWSLeft - when its column is less than limit_left - 1. Other entries are unchanged.
 *
 * Row of an entry:    row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m
 * Column of an entry: col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + v * 8
 */
template <bool HasWSLeft=true, typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_local(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                                const int max_seqlen_k, const int row_idx_offset_,
                                                const int max_seqlen_q,
                                                const int window_size_left, const int window_size_right) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15) * 2;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kRowTiles; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = first_row + tile_m * 32 + sub_m;
            const int limit_left = std::max(0, row + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
            const int limit_right = std::min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int tile_n = 0; tile_n < kColTiles; ++tile_n) {
                const int tile = tile_m + tile_n * kRowTiles;
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    const int frag = sub_n * 2 + sub_m;
                    const int col_base = first_col + tile_n * 32 + sub_n;
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        const int col = col_base + v * 8;
                        const bool past_right = col > limit_right;
                        const bool before_left = HasWSLeft && col < (limit_left - 1);
                        if (past_right || before_left) {
                            tensor[tile][frag].f32[v] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
/*
 * Add the ALiBi-style bias g_alibi * (col - row) to every entry of the score tile.
 *
 * Row of an entry:    row_idx_offset_ + (lane & 15) * 2 + tile_m * 32 + sub_m
 * Column of an entry: col_idx_offset_ + (lane >> 4) * 2 + tile_n * 32 + sub_n + v * 8
 * max_seqlen_k and max_seqlen_q are accepted but not referenced in the body.
 */
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_alibi(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_,
                                           const int max_seqlen_k, const int row_idx_offset_,
                                           const int max_seqlen_q, float g_alibi) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15) * 2;
    const int first_col = col_idx_offset_ + (lane >> 4) * 2;
    #pragma unroll
    for (int tile_m = 0; tile_m < kRowTiles; ++tile_m) {
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            const int row = first_row + tile_m * 32 + sub_m;
            #pragma unroll
            for (int tile_n = 0; tile_n < kColTiles; ++tile_n) {
                const int tile = tile_m + tile_n * kRowTiles;
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    const int frag = sub_n * 2 + sub_m;
                    const int col_base = first_col + tile_n * 32 + sub_n;
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        // Signed distance of this entry's column from its row.
                        const int distance = (col_base + v * 8) - row;
                        tensor[tile][frag].f32[v] += g_alibi * distance;
                    }
                }
            }
        }
    }
}
/*
 * Per-thread reduction of the score tile with `op` (the max step of the softmax bookkeeping).
 *
 * For every row tile and every row sub-tile (sub_m < M_MMAC_COUNT), all entries held by
 * this thread for that row are folded with `op` into one value:
 *   - zero_init == true : the accumulator starts at -INFINITY and lives in `summary`;
 *                         `summary_cur` is not touched.
 *   - zero_init == false: the accumulator starts from `summary` and lives in `summary_cur`;
 *                         `summary` is only read.
 * Results land in entry [row_tile * 2].f32[sub_m], so only even-indexed entries are written.
 * OpType is part of the template signature but is not referenced in the body.
 */
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void kvcache_thread_reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    // Buffer that holds the running result for this instantiation.
    DataType1 *result = zero_init ? summary : summary_cur;
    #pragma unroll
    for (int row_tile = 0; row_tile < kRowTiles; ++row_tile) {
        const int slot = row_tile * 2;
        #pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            if constexpr (zero_init) {
                result[slot].f32[sub_m] = -INFINITY;
            } else {
                result[slot].f32[sub_m] = summary[slot].f32[sub_m];
            }
            #pragma unroll
            for (int col_tile = 0; col_tile < kColTiles; ++col_tile) {
                #pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        result[slot].f32[sub_m] = op(result[slot].f32[sub_m], tensor[row_tile + col_tile * kRowTiles][sub_n * 2 + sub_m].f32[v]);
                    }
                }
            }
        }
    }
}
/*
 * Per-thread sum of the score tile along the column direction.
 *
 *   - zero_init == true : the accumulator starts at zero and lives in `summary`;
 *                         `summary_cur` is not touched.
 *   - zero_init == false: the accumulator starts from `summary` and lives in `summary_cur`;
 *                         `summary` is only read.
 * Results land in entry [m_idx * 2], so only even-indexed entries are written.
 *
 * Two implementations are selected by the target architecture:
 *   - gfx936 / gfx938 / gfx946: both row sub-tiles are accumulated together through the
 *     packed add __builtin_hcu_pk_add_f32 on the .u64 view. This path always processes
 *     sub-tiles 0 and 1 and uses neither `op` nor M_MMAC_COUNT.
 *   - otherwise: each of the M_MMAC_COUNT row sub-tiles is accumulated separately with `op`.
 * OpType is part of the template signature but is not referenced in the body.
 */
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// Clear both packed accumulators at once through the 64-bit view.
summary[m_idx * 2].u64 = 0x0;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
// Pair the same entry of row sub-tile 0 and row sub-tile 1 for one packed add.
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
}
}
}
#else
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
summary[m_idx * 2].f32[min_tile_m] = 0; // OpType:0 is sum operator, 1 is max operator
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
#endif
}
} else {
#pragma unroll
for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// Seed both packed accumulators from the previous sums.
summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64;
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// NOTE(review): unlike the zero_init branch, this innermost loop carries no `#pragma unroll`.
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
}
}
}
#else
#pragma unroll
for(int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// Seed the accumulator from the previous sum.
summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m];
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// NOTE(review): unlike the zero_init branch, this innermost loop carries no `#pragma unroll`.
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
}
}
}
#endif
}
}
}
/**
 * Applies Allreduce<64>::run with `op` to each of the WARP_M / 32 entries of `src`
 * and stores every result in the matching entry of `dst`.
 *
 * `dst` and `src` may be the same array; the callers in this file pass the same
 * pointer for both. The width of 64 presumably spans one full wavefront
 * (NOTE(review): confirm against the Allreduce definition).
 */
template<typename Operator, typename DataType, int WARP_M>
__device__ inline void kvcache_quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
    constexpr int kRowTiles = WARP_M / 32;
    #pragma unroll
    for (int tile = 0; tile < kRowTiles; ++tile) {
        dst[tile] = Allreduce<64>::run(src[tile], op);
    }
}
/**
 * Two-stage reduction of a score tile: a per-thread reduction followed by an
 * Allreduce across lanes (see kvcache_quad_allreduce_).
 *
 * OpType selects the per-thread reducer: 0 -> kvcache_thread_reduce_sum,
 * 1 -> kvcache_thread_reduce_max. Any other OpType is a no-op.
 *
 * zero_init == true : the result is produced in `summary`; `summary_cur` is not used.
 * zero_init == false: the per-thread reducer receives both `summary` and `summary_cur`,
 *                     and the cross-lane reduction is applied to `summary_cur`.
 */
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void kvcache_reduce_(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
    // Stage 1: per-thread reduction.
    if constexpr (OpType == 0) { // sum
        if constexpr (zero_init) {
            kvcache_thread_reduce_sum<true, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op);
        } else {
            kvcache_thread_reduce_sum<false, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
        }
    } else if constexpr (OpType == 1) { // max
        if constexpr (zero_init) {
            kvcache_thread_reduce_max<true, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op);
        } else {
            kvcache_thread_reduce_max<false, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, summary, op, summary_cur);
        }
    }

    // Stage 2: in-place reduction across lanes, on whichever buffer stage 1 wrote.
    if constexpr (OpType == 0 || OpType == 1) {
        DataType1 *reduced = zero_init ? summary : summary_cur;
        kvcache_quad_allreduce_<Operator, DataType1, WARP_M>(reduced, reduced, op);
    }
}
/**
 * Max reduction of a score tile, forwarded to kvcache_reduce_ with MaxOp<float>
 * and OpType 1.
 *
 * With zero_init == true the result goes to `max`. With zero_init == false,
 * `max` is passed through as the first summary buffer and the result goes to
 * `max_cur`.
 */
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void kvcache_reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *max , DataType1 *max_cur=nullptr) {
    MaxOp<float> reducer;
    // kvcache_reduce_ never touches its last argument when zero_init is true,
    // so one forwarding call covers both instantiations.
    kvcache_reduce_<zero_init, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, max, reducer, max_cur);
}
/**
 * Sum reduction of a score tile, forwarded to kvcache_reduce_ with SumOp<float>
 * and OpType 0.
 *
 * With zero_init == true the result goes to `sum`. With zero_init == false,
 * `sum` is passed through as the first summary buffer and the result goes to
 * `sum_cur`.
 */
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
__device__ inline void kvcache_reduce_sum(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){
    SumOp<float> reducer;
    // kvcache_reduce_ never touches its last argument when zero_init is true,
    // so one forwarding call covers both instantiations.
    kvcache_reduce_<zero_init, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(tensor, sum, reducer, sum_cur);
}
/**
 * In-place exponentiation of a score tile: every element x becomes
 * exp2(x * scale - max_scaled), where max_scaled is the row maximum multiplied
 * by `scale` (Scale_max == true) or by log2(e) (Scale_max == false).
 *
 * A row maximum of -inf means every element of the row was -inf (possibly due
 * to masking); max_scaled is then forced to 0 so that (-inf) - (-inf) never
 * produces NaN.
 *
 * The row maximum for row tile `r` and sub-row `s` is read from max[r * 2].f32[s].
 */
template <bool Scale_max=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N, int M_MMAC_COUNT>
inline __device__ void kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], const DataType1 *max, const float scale) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    #pragma unroll
    for (int row_tile = 0; row_tile < kRowTiles; ++row_tile) {
        #pragma unroll
        for (int sub_row = 0; sub_row < M_MMAC_COUNT; ++sub_row) {
            const float row_max = max[row_tile * 2].f32[sub_row];
            // float(M_LOG2E) keeps the multiplication in fp32 rather than fp64.
            const float max_factor = Scale_max ? scale : float(M_LOG2E);
            const float max_scaled = (row_max == -INFINITY) ? 0.f : (row_max * max_factor);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
            __float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
            __float2 scale_pair = {scale, scale};
#endif
            #pragma unroll
            for (int col_tile = 0; col_tile < kColTiles; ++col_tile) {
                // A 32x32 minimum tile is covered by 16x16x16 mmac results, hence
                // two sub-tiles along N (and up to two along M).
                #pragma unroll
                for (int sub_col = 0; sub_col < 2; ++sub_col) {
                    DataType0 &cell = tensor[row_tile + col_tile * kRowTiles][sub_col * 2 + sub_row];
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                    // Packed FMA computes x * scale - max_scaled for two floats at a time ...
                    #pragma unroll
                    for (int pair = 0; pair < 2; ++pair) {
                        cell.u64[pair] = __builtin_hcu_pk_fma_f32(cell.u64[pair], scale_pair, neg_max_scaled_pair);
                    }
                    // ... then exp2 is applied element by element.
                    #pragma unroll
                    for (int elem = 0; elem < 4; ++elem) {
                        cell.f32[elem] = __llvm_exp2_f32(cell.f32[elem]);
                    }
#else
                    // Written as exp2(x * scale - max * scale), not exp(x - max), so the
                    // compiler can emit a single fused multiply-add.
                    #pragma unroll
                    for (int elem = 0; elem < 4; ++elem) {
                        cell.f32[elem] = __llvm_exp2_f32(cell.f32[elem] * scale - max_scaled);
                    }
#endif
                }
            }
        }
    }
}
/**
 * Online-softmax update for one block of attention scores.
 *
 * Is_first == true:
 *   scores_max <- row max of `scores`; `scores` <- exp2(scores * scale - max * scale);
 *   scores_sum <- row sum of the exponentiated scores.
 *
 * Is_first == false:
 *   1. Compute the new running row max (scores_max_cur) from `scores` and the previous scores_max.
 *   2. If WARP_NUM > 1, combine the maxima of all waves through LDS (`max_lds`).
 *   3. Rescale scores_sum and acc_o by exp2((old_max - new_max) * softmax_scale_log2).
 *   4. Exponentiate `scores` against the new max and compute this block's row sum.
 *   5. If WARP_NUM > 1, add up the per-wave row sums through LDS (region starting at max_lds + 64).
 *   6. Accumulate the block sum into scores_sum and store the new max in scores_max.
 *
 * Block-wide barriers: when WARP_NUM > 1 and !Is_first this function executes __syncthreads()
 * five times, so every thread of the block must call it. All barriers are issued from
 * lane-uniform control flow; only the LDS accesses themselves are restricted to lanes 0..15.
 *
 * NOTE(review): the cross-wave LDS paths index scores_max_cur[0] and use `lane_id * 2` offsets
 * within a 32-entry row group, and some statements index with `mi * 2` while others use `mi`.
 * These agree only when WARP_M / 32 == 1 -- confirm that no instantiation uses a larger WARP_M.
 *
 * @param scores              score tile, updated in place
 * @param scores_max          running row max (read as previous max, written with the new max)
 * @param scores_sum          running row sum (rescaled and accumulated)
 * @param acc_o               output accumulator, rescaled in place when !Is_first
 * @param max_lds             LDS scratch used for the cross-wave max and sum exchange
 * @param warp_id             index of the calling wave inside the block
 * @param softmax_scale_log2  scale applied to scores before exp2
 */
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1, typename DataType2, int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int M_MMAC_COUNT>
inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_M / 32)][4], DataType1 *scores_max, DataType1 *scores_sum,
    DataType0 acc_o[(K / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4], DataType2* max_lds, int warp_id, float softmax_scale_log2) {
    if constexpr (Is_first) {
        kvcache_reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max);
        kvcache_scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, softmax_scale_log2);
        kvcache_reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum);
    } else {
        DataType1 scores_max_cur[(WARP_M / 32)];
        kvcache_reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
        int lane_id = threadIdx.x & 63;
        if constexpr (WARP_NUM > 1) {
            int dword_offset_base = (lane_id & 15);
            // Wave 0 resets the shared max slots to -inf.
            if (lane_id < 16 && warp_id == 0) {
                for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
                    max_lds[dword_offset_base + m_loop * 32] = -INFINITY;
                }
            }
            // Barrier hoisted out of the lane_id < 16 branch so that it is reached in uniform control flow.
            __syncthreads();
            // Lanes 0..15 of every wave fold their row max into LDS with an LDS float-max operation.
            if (lane_id < 16) {
                for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
                    __builtin_amdgcn_ds_fmaxf((__attribute__((address_space(3))) float *)max_lds + dword_offset_base + m_loop * 32, scores_max_cur[0].f32[m_loop], 0, 0, false);
                }
            }
            __syncthreads();
            // Every lane reads back the block-wide max for its row.
            for (int m_loop = 0; m_loop < M_MMAC_COUNT; ++m_loop) {
                scores_max_cur[0].f32[m_loop] = max_lds[dword_offset_base + m_loop * 32];
            }
        }
        #pragma unroll
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
            // If max is -inf, then all elements must have been -inf (possibly due to masking).
            // We don't want (-inf - (-inf)) since that would give NaN.
            #pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                float scores_max_cur_reg = !Check_inf
                    ? scores_max_cur[mi * 2].f32[min_tile_m]
                    : (scores_max_cur[mi * 2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi * 2].f32[min_tile_m]);
                // Correction factor that re-bases previously accumulated values onto the new max.
                float scores_scale = __llvm_exp2_f32((scores_max[mi * 2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
                scores_sum[mi * 2].f32[min_tile_m] *= scores_scale;
                __float2 scores_scale_pair = {scores_scale, scores_scale};
                #pragma unroll
                for (int pv_n_loop = 0; pv_n_loop < (K / kBlockK); ++pv_n_loop) {
                    #pragma unroll
                    for (int ni = 0; ni < (kBlockK / 32); ++ni) {
                        // min tile is 32 * 32, mmac size is 16x16x16, so min_tile_n=32/16, min_tile_m=32/16
                        #pragma unroll
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            // gfx936 and later architectures provide the pk_mul instruction.
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                            #pragma unroll
                            for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
                                acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
                                    acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx],
                                    scores_scale_pair
                                );
                            }
#else
                            // gfx928 and earlier architectures have no pk_mul instruction.
                            #pragma unroll
                            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                                acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].f32[vec_idx] *= scores_scale;
                            }
#endif
                        }
                    }
                }
            }
        }
        kvcache_scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max_cur, softmax_scale_log2);
        DataType1 scores_sum_cur[(WARP_M / 32)];
        // Clear the whole 64-bit entry first, before the reduction below writes into it.
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
            scores_sum_cur[mi].u64 = 0x0;
        }
        kvcache_reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum_cur);
        if constexpr (WARP_NUM > 1) {
            // Recompute the normalization sum across all waves.
            DataType2* sum_lds = max_lds + 64;
            // Each wave writes its own normalization sum to LDS (lanes 0..15 only).
            if (lane_id < 16) {
                for (int mi = 0; mi < (WARP_M / 32); ++mi) {
                    if constexpr (M_MMAC_COUNT == 1) {
                        sum_lds[warp_id * WARP_M + mi * 32 + lane_id * 2] = scores_sum_cur[mi].f32[0];
                    } else {
                        *(__float2*)(sum_lds + warp_id * WARP_M + mi * 32 + lane_id * 2) = scores_sum_cur[mi].u64; // M_MMAC_COUNT doesn't exceed 2
                    }
                }
            }
            // Barrier hoisted out of the lane_id < 16 branch so that it is reached in uniform control flow.
            __syncthreads();
            // Wave 0 reduces the sums of the other waves into its own slot.
            if (lane_id < 16 && warp_id == 0) {
                for (int mi = 0; mi < (WARP_M / 32); ++mi) {
                    if constexpr (M_MMAC_COUNT == 1) {
                        float tmp = sum_lds[mi * 32 + lane_id * 2];
                        for (int warp_loop = 1; warp_loop < WARP_NUM; warp_loop++) {
                            tmp += sum_lds[warp_loop * WARP_M + mi * 32 + lane_id * 2];
                        }
                        sum_lds[mi * 32 + lane_id * 2] = tmp;
                    } else {
                        __float2 cur_wave_sum = *(__float2*)(sum_lds + mi * 32 + lane_id * 2);
                        #pragma unroll
                        for (int warp_loop = 1; warp_loop < WARP_NUM; warp_loop++) {
                            __float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                            cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
                            cur_wave_sum[0] += other_warp_sum[0];
                            cur_wave_sum[1] += other_warp_sum[1];
#endif
                        }
                        *(__float2*)(sum_lds + mi * 32 + lane_id * 2) = cur_wave_sum;
                    }
                }
            }
            __syncthreads();
            // Every wave reads the final reduced normalization sum back from LDS.
            for (int mi = 0; mi < (WARP_M / 32); ++mi) {
                if constexpr (M_MMAC_COUNT == 1) {
                    scores_sum_cur[mi * 2].f32[0] = sum_lds[mi * 32 + (lane_id & 15) * 2];
                } else {
                    scores_sum_cur[mi * 2].u64 = *(__float2*)(sum_lds + mi * 32 + (lane_id & 15) * 2);
                }
            }
            __syncthreads(); // keeps a later buffer_load_to_lds from being scheduled before this point
        }
        for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
            scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
                scores_sum[mi].u64,
                scores_sum_cur[mi].u64
            );
            // The new running max replaces the previous one.
            scores_max[mi].u64 = scores_max_cur[mi].u64;
#else
            scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
            scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
            // The new running max replaces the previous one.
            scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
            scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
#endif
        }
    }
}
/**
 * Down-converts the fp32 score tile `s_reg` into the 16-bit `Element` tile `p_reg`.
 *
 * Every [tile][slot] entry holds four values. On gfx938 / gfx946 / gfx92a they are
 * converted two at a time with DownCastPair; on other targets one at a time with
 * DownCast<float, Element, false>.
 *
 * The template parameters M_MMAC_COUNT and ElementAccum are not referenced in the
 * body; both sub-rows of every tile are always converted.
 */
template <int WARP_M, int WARP_N, int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void kvcache_convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (WARP_N / 32)][4], union_vec4_fp32 s_reg[(WARP_M / 32) * (WARP_N / 32)][4]) {
    constexpr int kRowTiles = WARP_M / 32;
    constexpr int kColTiles = WARP_N / 32;
    #pragma unroll
    for (int col_tile = 0; col_tile < kColTiles; ++col_tile) {
        #pragma unroll
        for (int row_tile = 0; row_tile < kRowTiles; ++row_tile) {
            const int tile = col_tile * kRowTiles + row_tile;
            #pragma unroll
            for (int sub_row = 0; sub_row < 2; ++sub_row) {
                #pragma unroll
                for (int pair = 0; pair < 2; ++pair) {
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                    // Packed conversion: one call handles a pair of floats.
                    #pragma unroll
                    for (int half = 0; half < 2; ++half) {
                        const int slot = half * 2 + sub_row;
                        p_reg[tile][slot].f16x2[pair] = DownCastPair<float, Element>(s_reg[tile][slot].f32x2[pair]);
                    }
#else
                    // Scalar conversion: first element of the pair for both slots, then the second.
                    #pragma unroll
                    for (int elem = 0; elem < 2; ++elem) {
                        #pragma unroll
                        for (int half = 0; half < 2; ++half) {
                            const int slot = half * 2 + sub_row;
                            const int lane = pair * 2 + elem;
                            p_reg[tile][slot].f16[lane] = DownCast<float, Element, false>(s_reg[tile][slot].f32[lane]);
                        }
                    }
#endif
                }
            }
        }
    }
}
#pragma once
#include "fwd/utils.h"
using namespace flash;
/**
 * Sequence-length mask: sets every score whose key column is >= max_seqlen_k to -inf.
 *
 * The key column of element [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] is
 *   col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4,
 * with lane_id = threadIdx.x & 63. It does not depend on the row indices, so once a
 * column is out of range the same element slot is masked for every row tile and sub-row.
 */
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_tile16x32(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k,
                                                    const int col_idx_offset_ = 0) {
    const int lane = threadIdx.x & 63; // lane id, 0-63
    const int first_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int col_tile = 0; col_tile < N_WARP_COUNT; ++col_tile) {
        #pragma unroll
        for (int sub_col = 0; sub_col < 2; ++sub_col) {
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem) {
                const int col = first_col + col_tile * 32 + sub_col * 16 + elem * 4;
                if (col < max_seqlen_k) {
                    continue; // column is inside the valid key range
                }
                #pragma unroll
                for (int row_tile = 0; row_tile < M_WARP_COUNT; ++row_tile) {
                    #pragma unroll
                    for (int sub_row = 0; sub_row < M_MMAC_COUNT; ++sub_row) {
                        tensor[row_tile + col_tile * M_WARP_COUNT][sub_col * 2 + sub_row].f32[elem] = -INFINITY;
                    }
                }
            }
        }
    }
}
/**
 * Causal mask: sets a score to -inf when its key column lies strictly to the right of
 * the limit computed for its query row.
 *
 *   row   = row_idx_offset_ + (lane_id & 15) + mi * 32 + min_tile_m * 16
 *   col   = col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4
 *   limit = min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups)
 *
 * The row / ngroups mapping is only valid for layout 1 (bshd), per the original author's note.
 */
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_tile16x32(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                           const int max_seqlen_k, const int row_idx_offset_,
                                                           const int max_seqlen_q, const int ngroups) {
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15);
    const int first_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int row_tile = 0; row_tile < M_WARP_COUNT; ++row_tile) {
        #pragma unroll
        for (int sub_row = 0; sub_row < M_MMAC_COUNT; ++sub_row) {
            const int row = first_row + row_tile * 32 + sub_row * 16;
            const int right_limit = std::min(max_seqlen_k, (row / ngroups)/*only for layout 1: bshd*/ + max_seqlen_k - (max_seqlen_q / ngroups));
            #pragma unroll
            for (int col_tile = 0; col_tile < N_WARP_COUNT; ++col_tile) {
                #pragma unroll
                for (int sub_col = 0; sub_col < 2; ++sub_col) {
                    #pragma unroll
                    for (int elem = 0; elem < 4; ++elem) {
                        const int col = first_col + col_tile * 32 + sub_col * 16 + elem * 4;
                        if (col > right_limit) {
                            tensor[row_tile + col_tile * M_WARP_COUNT][sub_col * 2 + sub_row].f32[elem] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
/**
 * Sliding-window (local) causal mask: sets a score to -inf when its key column falls
 * outside the window computed for its query row.
 *
 *   row         = row_idx_offset_ + (lane_id & 15) + mi * 32 + min_tile_m * 16
 *   col         = col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4
 *   left_limit  = max(0, row / ngroups + 1 + max_seqlen_k - max_seqlen_q / ngroups - window_size_left)
 *   right_limit = min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups + window_size_right)
 *
 * A score is masked when col > right_limit or col < left_limit - 1.
 * The row / ngroups mapping is only valid for layout 1 (bshd), per the original author's note.
 */
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_local_mask_causal_tile16x32(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                                 const int max_seqlen_k, const int row_idx_offset_,
                                                                 const int max_seqlen_q, const int ngroups,
                                                                 const int window_size_left, const int window_size_right) {
    const int lane = threadIdx.x & 63;
    const int first_row = row_idx_offset_ + (lane & 15);
    const int first_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int row_tile = 0; row_tile < M_WARP_COUNT; ++row_tile) {
        #pragma unroll
        for (int sub_row = 0; sub_row < M_MMAC_COUNT; ++sub_row) {
            const int row = first_row + row_tile * 32 + sub_row * 16;
            const int left_limit = std::max(0, (row / ngroups) + 1 + max_seqlen_k - (max_seqlen_q / ngroups) - window_size_left);
            const int right_limit = std::min(max_seqlen_k, (row / ngroups)/*only for layout 1: bshd*/ + max_seqlen_k - (max_seqlen_q / ngroups) + window_size_right);
            #pragma unroll
            for (int col_tile = 0; col_tile < N_WARP_COUNT; ++col_tile) {
                #pragma unroll
                for (int sub_col = 0; sub_col < 2; ++sub_col) {
                    #pragma unroll
                    for (int elem = 0; elem < 4; ++elem) {
                        const int col = first_col + col_tile * 32 + sub_col * 16 + elem * 4;
                        const bool outside_window = (col > right_limit) || (col < left_limit - 1);
                        if (outside_window) {
                            tensor[row_tile + col_tile * M_WARP_COUNT][sub_col * 2 + sub_row].f32[elem] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
/**
 * Causal mask for the MTP variant: sets a score to -inf when its key column lies
 * strictly to the right of the limit computed for its query row.
 *
 *   row        = row_idx_offset_ + (lane_id & 15) + mi * 32 + min_tile_m * 16
 *   col        = col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4
 *   row_in_mtp = (layout == 0) ? row % mtp : row / (max_seqlen_q / mtp)
 *   limit      = min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp)
 *
 * @param tensor           score tile, masked in place
 * @param col_idx_offset_  key-column offset of this tile
 * @param max_seqlen_k     key sequence length used for the limit
 * @param row_idx_offset_  query-row offset of this tile
 * @param max_seqlen_q     number of query rows; max_seqlen_q / mtp is the divisor when layout != 0
 * @param mtp              modulus (layout == 0) and window length in the limit; must be non-zero
 * @param layout           0 selects row % mtp, any other value selects row / (max_seqlen_q / mtp)
 */
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_tile16x32_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                               const int max_seqlen_k, const int row_idx_offset_,
                                                               const int max_seqlen_q, const int mtp, const int layout) {
    const int MTP_REGROUP_COUNT = max_seqlen_q / mtp;
    const int lane_id = threadIdx.x & 63;
    // Parenthesized on purpose: `+` binds tighter than `&`, so the unparenthesized form
    // `row_idx_offset_ + lane_id & 15` masked the whole sum and dropped the row offset.
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4);
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int row_in_mtp = layout == 0 ? (row_idx % mtp) : (row_idx / MTP_REGROUP_COUNT);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx * 4;
                        tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY : tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
                    }
                }
            }
        }
    }
}
#pragma once
#include "numeric_types.h"
// Cross-warp reduction of the acc_o accumulator tiles through the LDS buffer acc_o_lds.
//
// Per k_loop step (stride PREFETCH) every warp writes sub-entries [0] and [2] of PREFETCH consecutive
// acc_o tiles into its own region of acc_o_lds (base offset warp_id * 2048 elements; tile `prefetch`
// at + prefetch * 2 * 16 * 16, sub-entry at + min_tile_n * 16 * 16). Each warp then reads, from every
// warp's region, the slice at + warp_id * 2 * 16 * 16 and sums those slices into
// acc_o[k_loop + 0][0] and acc_o[k_loop + 0][2], which are zeroed first.
//
// Three implementations are selected by the preprocessor; the first one (`#if 1`) is the one compiled.
//
// Not referenced in the body: template parameters REUSE_KV_TIMES, M_MMAC_COUNT and Padding, and the
// runtime parameter seqlen_q.
// NOTE(review): acc_o is indexed with k_loop + prefetch for prefetch < PREFETCH, so this presumably
// requires K_LOOP_COUNT to be a multiple of PREFETCH -- confirm at the call sites.
// NOTE(review): the wait_lds_data_arrived arguments look like counts of still-outstanding LDS
// operations; the statement order below must stay in sync with them -- confirm against fwd/utils.h.
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, int Padding, typename ElementAccum>
__forceinline__ __device__ void fp8_mla_acco_reduce_tile16x32(
vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
ElementAccum* acc_o_lds,
int seqlen_q,
int warp_id,
int lane_id) {
#if 1
// Active path: inline-asm LDS writes, ds_read2 loads issued without waiting, packed adds afterwards.
constexpr int PREFETCH = WARP_NUM;
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH) {
// Stage 1: write this warp's PREFETCH tiles (sub-entries 0 and 2) to its LDS region.
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int prefetch = 0; prefetch < PREFETCH; ++prefetch) {
vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
// __builtin_hcu_ds_write_matrix_format_f32(f32x4, acc_o_lds + lds_write_offset, 0, 1, 1, 0, 0);
// The LDS address is truncated to 32 bits and handed to the instruction through m0.
int lds_address = reinterpret_cast<size_t>(acc_o_lds + lds_write_offset);
asm volatile(
"s_mov_b32 m0, %0\n\t"
"s_nop 0\n\t"
"ds_write_matrix_format %1, m0 element:3 row:1 col:1\n"
:: "s"(lds_address), "v"(f32x4)
:);
}
}
// data[sub-entry][neighbor]: the slice belonging to this warp, as written by each warp.
union_vec4_fp32 data[2][WARP_NUM];
// 4 * 2 bursts of ds_read2_b32
{
// Stage 2a: issue the loads for sub-entry 0, then clear the accumulator they will be added to.
constexpr int min_tile_n = 0;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH);
#pragma unroll
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
inline_ds_read2_b32_no_wait(acc_o_lds, neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id, data[min_tile_n][neighbor].u64[0], 64);
inline_ds_read2_b32_no_wait(acc_o_lds, neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 2 * 64, data[min_tile_n][neighbor].u64[1], 64);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
// 4 * 2 bursts of ds_read2_b32
constexpr int ds_bursts = WARP_NUM * 2;
{
// Stage 2b: same for sub-entry 1 (stored in acc_o[...][2]).
constexpr int min_tile_n = 1;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH + ds_bursts);
#pragma unroll
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
inline_ds_read2_b32_no_wait(acc_o_lds, neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id, data[min_tile_n][neighbor].u64[0], 64);
inline_ds_read2_b32_no_wait(acc_o_lds, neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 2 * 64, data[min_tile_n][neighbor].u64[1], 64);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
// wait burst1 data arrived
// Stage 3a: accumulate sub-entry 0 neighbor by neighbor; the wait argument shrinks by 2 per neighbor.
#pragma unroll
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
constexpr int min_tile_n = 0;
flash::wait_lds_data_arrived<false>(ds_bursts - 2 - neighbor * 2 + ds_bursts);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
// wait burst2 data arrived
// Stage 3b: accumulate sub-entry 1 the same way.
#pragma unroll
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
constexpr int min_tile_n = 1;
flash::wait_lds_data_arrived<false>(ds_bursts - 2 - neighbor * 2);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
// sync
// Presumably a block-wide barrier so the LDS buffer can be rewritten in the next k_loop step.
flash::wait_all_warp_arrived();
}
#elif 1
// #######################################################################################################################################
// default path
// #######################################################################################################################################
// Not compiled while the `#if 1` above is in place. Uses the write builtin, plain LDS loads and
// __syncthreads(); hard-codes 4 for PREFETCH and for the neighbor count instead of WARP_NUM.
constexpr int PREFETCH = 4;
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH) {
#pragma unroll
for (int prefetch = 0; prefetch < PREFETCH; ++prefetch) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
__builtin_hcu_ds_write_matrix_format_f32(f32x4, acc_o_lds + lds_write_offset, 0, 1, 1, 0, 0);
}
}
// All writes must be visible before any warp reads its neighbors' regions.
__syncthreads();
acc_o[k_loop + 0][0].b64[0] = 0x0;
acc_o[k_loop + 0][0].b64[1] = 0x0;
acc_o[k_loop + 0][2].b64[0] = 0x0;
acc_o[k_loop + 0][2].b64[1] = 0x0;
#pragma unroll
for (int neighbor = 0; neighbor < 4; ++neighbor) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
union_vec4_fp32 data;
data.f32[0] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id]; // ds_read2st64_b32 generated for acceleration
data.f32[1] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 1 * 64];
data.f32[2] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 2 * 64];
data.f32[3] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 3 * 64];
acc_o[k_loop + 0][min_tile_n * 2].u64[0] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], data.u64[0]);
acc_o[k_loop + 0][min_tile_n * 2].u64[1] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], data.u64[1]);
}
}
// Reads must finish before the next k_loop step overwrites the LDS buffer.
__syncthreads();
}
#else
// #######################################################################################################################################
// bank-conflicts free path, but lower performance
// #######################################################################################################################################
// Not compiled while the `#if 1` above is in place. Same structure as the active path, but each lane
// writes and reads its four values as one 128-bit access at element offset lane_id * 4.
constexpr int PREFETCH = WARP_NUM;
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int prefetch = 0; prefetch < PREFETCH; ++prefetch) {
vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
// lds_write_offset is reused here: it now holds the truncated LDS address, not an element offset.
lds_write_offset = reinterpret_cast<size_t>(acc_o_lds + lds_write_offset + lane_id * 4);
inlineasm_ds_write_b128(lds_write_offset, f32x4);
}
}
union_vec4_fp32 data[2][WARP_NUM];
constexpr int ds_bursts = PREFETCH;
{
constexpr int min_tile_n = 0;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH);
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
{
constexpr int min_tile_n = 1;
flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH + ds_bursts);
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
}
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
}
{
constexpr int min_tile_n = 0;
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor + ds_bursts);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
}
{
constexpr int min_tile_n = 1;
#pragma unroll
for (int neighbor = 0; neighbor < PREFETCH; ++neighbor) {
flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
}
}
flash::wait_all_warp_arrived();
}
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment