Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
#pragma once
#include "intrinsic_mls_ds.h"
#define USE_MLS_128B_REQUEST
// Prefetches this warp's slice of the Q tile from global memory into LDS using
// the matrix-load (MLS) builtins.
//
// Template parameters:
//   Is_even_MN - not referenced anywhere in this body.
//   kHeadDim   - used only to compute the per-warp LDS byte offset.
//   WARP_M     - rows handled per warp; scales the LDS offset and the global row offset.
//   Element    - element type; sizeof(Element) scales the byte offsets.
// Parameters:
//   q_ptr            - base pointer handed to prepare_for_matrix_load.
//   q_lds            - LDS destination base, addressed in bytes.
//   warp_id          - warp index; selects the LDS slice and the global row offset.
//   q_row_stride     - written to descriptor word 2 and used to scale the row offset.
//   max_seq_q_offset - bound used to derive nm_filter for the boundary handling.
//
// NOTE(review): prepare_for_matrix_load is instantiated with the literal 128 here,
// whereas fp8_prefetch_k_to_lds passes kHeadDim — confirm this is intentional.
template<bool Is_even_MN, int kHeadDim, int WARP_M, typename Element>
__forceinline__ __device__ void fp8_prefetch_q_to_lds(
Element* q_ptr,
int8_t* q_lds,
int warp_id,
int q_row_stride,
int max_seq_q_offset
) {
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint q_root = prepare_for_matrix_load<128, Element>(q_ptr);
vec4_uint q_srsrc;
// Words 0 and 3 set here are overwritten further down before the first load.
q_srsrc[0] = q_root[0];
q_srsrc[1] = q_root[1];
q_srsrc[2] = q_row_stride; // stride
q_srsrc[3] = 0x40000; // [17: 18], interleave 4
// Compute the LDS write address.
int q_lds_offset = warp_id * WARP_M * kHeadDim/*4K bytes, 16K bytes in total*/ * sizeof(Element);
// Only consumed by the non-128B (#else) path below; note the size_t -> int narrowing.
int q_lds_write_bytes = reinterpret_cast<size_t>(q_lds) + q_lds_offset;
// Compute the global read address.
q_srsrc[0] = q_root[0] + (warp_id * WARP_M) * q_row_stride * sizeof(Element);
// Boundary check.
// presumably inline_min_max<0, 16> clamps its argument to [0, 16] — helper not visible here.
// NOTE(review): the row base uses the literal 32 * warp_id rather than WARP_M — confirm.
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_q_offset);
// q_srsrc[3] = q_srsrc[3] + max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8; // set only once
// The filter is placed at bit 8 and up; it is skipped when max_seq_q_offset is a multiple of 128.
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8)); // set only once
// printf("nm_filter is %d, max_seq_q_osffset is %d\n", max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8, max_seq_q_offset);
// Issue the MLS reads.
#ifdef USE_MLS_128B_REQUEST
// inline_matrix_load_128x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_offset/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
// First load: LDS offset +0, third argument 0.
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset, 0, true, false, false, false, false);
// Recompute the filter for the second load (bound shifted by another 16).
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_q_offset);
q_srsrc[3] = 0x40000 + ((max_seq_q_offset % 128 == 0) ? 0: (nm_filter << 8));
// Second load: LDS offset +512 bytes, third argument 16.
__builtin_hcu_matrix_load_128X16_b8(q_srsrc, q_lds+q_lds_offset+512, 16, true, false, false, false, false);
#else
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_write_bytes/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
// Advance the source address by 64 elements for the second load.
q_srsrc[0] = q_srsrc[0] + 64 * sizeof(Element);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_write_bytes + 2048/*lds bytes*/, 0/*matrix_offset, 0 or 16*/); // For Q, a 128x16 or non-4-interleave form could be considered.
#endif
// Scheduling barrier: keeps the compiler from moving instructions across this point.
__builtin_amdgcn_sched_barrier(0);
}
// #define USE_DS_READ_B128_FOR_INTERLEAVE4
// Loads this warp's Q data from LDS into the q_regs register fragments.
//
// Waits for outstanding vector-memory operations (s_waitcnt vmcnt(0)), performs
// four reads at byte offsets 0, 1024, 2048 and 3072 from the warp's LDS slice,
// and finishes with a block-wide __syncthreads().
//
// Parameters:
//   q_regs  - destination fragments; indices [0..1][0..1] are written unconditionally,
//             so WARP_M / 16 and kHeadDim / 64 must both be at least 2.
//   q_lds   - LDS base pointer, addressed in bytes.
//   warp_id - selects the warp's slice at warp_id * WARP_M * kHeadDim bytes.
//   lane_id - only used by the USE_DS_READ_B128_FOR_INTERLEAVE4 path.
//
// NOTE(review): the slice offset here has no sizeof(Element) factor, unlike the
// offset in fp8_prefetch_q_to_lds; the two agree only for 1-byte elements — confirm.
template<int kHeadDim, int WARP_M, typename Element>
__forceinline__ __device__ void load_q_from_lds_to_vgpr(
union_vec16_fp8 q_regs[WARP_M / 16][kHeadDim / 64],
int8_t* q_lds,
int warp_id,
int lane_id
) {
__builtin_amdgcn_sched_barrier(0);
// Wait until all outstanding vector-memory operations have completed.
asm volatile("s_waitcnt vmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
// The LDS data was written to two places; note this is the rearrange layout, so the step is 1K (transpose would step 2K).
// MLS0: [0: 512) and [1024, 1536)
// MLS1: [512: 1024) and [1536, 2048)
// Registers are filled in 4 reads; the first covers [0, 1024), i.e. a 16x64 tile, with each thread reading 16 fp8 values.
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
// Per-lane address: row from lane bits [3:1], column from lane bits [5:4],
// with the column rotated by the row (mod 4).
int row = (lane_id & 15) >> 1;
int col = lane_id >> 4;
int col_swizzle = (row + col) & 3;
int lds_load_offset = row * 128 + col_swizzle * 16 + (lane_id & 1) * 64 + warp_id * WARP_M * kHeadDim;
q_regs[0][0].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 0);
q_regs[1][0].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 1024/*ds fmt 0, dmft1 */);
q_regs[0][1].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 2048/*ds fmt 0, dmft1 */);
q_regs[1][1].i32x4 = *(vec4_int32*)(q_lds + lds_load_offset + 3072/*ds fmt 0, dmft1 */);
#else
// Same four 1K-strided reads, done through the matrix-read builtin.
q_regs[0][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + 0 + warp_id * WARP_M * kHeadDim, 0, 3, 1, 0);
q_regs[1][0].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + 1024 + warp_id * WARP_M * kHeadDim, 0, 3, 1, 0);
q_regs[0][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + 2048 + warp_id * WARP_M * kHeadDim, 0, 3, 1, 0);
q_regs[1][1].i32x4 = __builtin_hcu_ds_read_matrix_trans_format_u8(q_lds + 3072 + warp_id * WARP_M * kHeadDim, 0, 3, 1, 0);
#endif
__builtin_amdgcn_sched_barrier(0);
// Block-wide barrier; every thread of the block must reach this call.
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
// Prefetches this warp's slice of the K tile from global memory into LDS using
// the matrix-load (MLS) builtins.
//
// Template parameters:
//   Is_even_MN - not referenced anywhere in this body.
//   kHeadDim   - passed to prepare_for_matrix_load and used in the LDS offset.
//   WARP_N     - rows per warp used for the LDS offset.
//   Element    - element type; sizeof(Element) scales the byte offsets.
// Parameters:
//   k_ptr             - base pointer handed to prepare_for_matrix_load.
//   k_lds             - LDS destination base, addressed in bytes.
//   warp_id           - warp index; selects the LDS slice and the global row offset.
//   k_row_stride      - written to descriptor word 2 and used to scale the row offset.
//   max_seq_kv_offset - bound used to derive nm_filter for the boundary handling.
//
// Calls flash::wait_all_warp_arrived() before issuing the loads, so all warps
// are expected to call this function together.
template<bool Is_even_MN, int kHeadDim, int WARP_N, typename Element>
__forceinline__ __device__ void fp8_prefetch_k_to_lds(
Element* k_ptr,
int8_t* k_lds,
int warp_id,
int k_row_stride,
int max_seq_kv_offset
) {
// Prepare the MLS descriptor registers and fill in the stride.
vec4_uint k_root = prepare_for_matrix_load<kHeadDim, Element>(k_ptr);
vec4_uint k_srsrc;
k_srsrc[0] = k_root[0];
k_srsrc[1] = k_root[1];
k_srsrc[2] = k_row_stride; // stride
k_srsrc[3] = 0x40000; // [17: 18], interleave 4
// Compute the LDS write address.
int k_lds_offset = warp_id * WARP_N * kHeadDim * sizeof(Element);
// Only consumed by the non-128B (#else) path below; note the size_t -> int narrowing.
int k_lds_write_bytes = reinterpret_cast<size_t>(k_lds) + k_lds_offset;
// Compute the global read address.
// NOTE(review): the global row offset uses the literal 32 while the LDS offset uses WARP_N — confirm they are meant to match.
k_srsrc[0] = k_root[0] + warp_id * 32 * k_row_stride * sizeof(Element);
// Boundary check.
// presumably inline_min_max<0, 16> clamps its argument to [0, 16] — helper not visible here.
int nm_filter = inline_min_max<0, 16>(32 * warp_id + 16 - max_seq_kv_offset);
// Adds the filter (bit 8 and up) onto the 0x40000 already in word 3; skipped when
// max_seq_kv_offset is a multiple of 128.
k_srsrc[3] = k_srsrc[3] + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
// Synchronize all warps so the srsrc parameters are ready before the MLS load is issued.
flash::wait_all_warp_arrived();
// Issue the MLS reads.
#ifdef USE_MLS_128B_REQUEST
// First load: LDS offset +0, third argument 0.
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset, 0, true, false, false, false, false);
// Recompute the filter for the second load (bound shifted by another 16).
nm_filter = inline_min_max<0, 16>(32 * warp_id + 32 - max_seq_kv_offset);
k_srsrc[3] = 0x40000 + ((max_seq_kv_offset % 128 == 0) ? 0: (nm_filter << 8));
// Second load: LDS offset +512 bytes, third argument 16.
__builtin_hcu_matrix_load_128X16_b8(k_srsrc, k_lds+k_lds_offset+512, 16, true, false, false, false, false);
#else
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, k_lds_write_bytes/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
// Advance the source address by 64 elements for the second load.
k_srsrc[0] = k_srsrc[0] + 64 * sizeof(Element);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, k_lds_write_bytes + 2048/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
#endif
}
\ No newline at end of file
#pragma once
#include "philox.cuh"
#include "../utils.h"
using namespace flash;
// Sequence-length mask: overwrites with -INFINITY every score whose column
// index is greater than or equal to max_seq_kv_offset.
//
// The column index of an element is
//   wave_col_offset + (lane_id >> 4) * 8 + n_frag * 4 + tile * WARP_N + elem
// and does not depend on the M fragment, so the predicate is evaluated once
// per column and applied to every M fragment.
//
// Parameters:
//   s_reg             - per-thread score fragments, modified in place.
//   max_seq_kv_offset - first column index that must be masked.
//   wave_col_offset   - column offset added to every column index.
//   lane_id           - lane index; bits 4 and above select an 8-column group.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int max_seq_kv_offset,
    int wave_col_offset,
    int lane_id
) {
    constexpr int kTilesN = kBlockN / WARP_N;
    constexpr int kFragsM = WARP_M / 16;
    constexpr int kFragsN = WARP_N / 16;

    __builtin_amdgcn_sched_barrier(0);
    const int lane_col = wave_col_offset + (lane_id >> 4) * 8;
    #pragma unroll
    for (int tile = 0; tile < kTilesN; ++tile) {
        #pragma unroll
        for (int frag_n = 0; frag_n < kFragsN; ++frag_n) {
            // Column of element 0 of this fragment.
            const int first_col = lane_col + frag_n * 4 + tile * WARP_N;
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem) {
                const bool past_end = (first_col + elem) >= max_seq_kv_offset;
                #pragma unroll
                for (int frag_m = 0; frag_m < kFragsM; ++frag_m) {
                    if (past_end) {
                        s_reg[tile][frag_m][frag_n].f32[elem] = -INFINITY;
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Causal mask: for each row, overwrites with -INFINITY every score whose
// column index is strictly greater than
//   min(actual_seqlen_k, row + actual_seqlen_k - actual_seqlen_q).
//
// Index mapping used by this function:
//   row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3) + m_frag * 4
//   col = wave_col_offset + (lane_id >> 4) * 8 + n_frag * 4 + tile * WARP_N + elem
//
// Parameters:
//   s_reg            - per-thread score fragments, modified in place.
//   actual_seqlen_q  - query length used for the diagonal shift.
//   actual_seqlen_k  - key length used for the diagonal shift and as the clamp.
//   wave_row_offset  - row offset added to every row index.
//   wave_col_offset  - column offset added to every column index.
//   lane_id          - lane index within the wave.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_causal_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int actual_seqlen_q,
    int actual_seqlen_k,
    int wave_row_offset,
    int wave_col_offset,
    int lane_id
) {
    constexpr int kTilesN = kBlockN / WARP_N;
    constexpr int kFragsM = WARP_M / 16;
    constexpr int kFragsN = WARP_N / 16;

    __builtin_amdgcn_sched_barrier(0);
    const int lane_row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int lane_col = wave_col_offset + (lane_id >> 4) * 8;
    // Offset of the causal diagonal when the K and Q lengths differ.
    const int diag_shift = actual_seqlen_k - actual_seqlen_q;
    #pragma unroll
    for (int frag_m = 0; frag_m < kFragsM; ++frag_m) {
        const int row = lane_row + frag_m * 4;
        // Largest column index this row keeps.
        const int last_kept_col = min(actual_seqlen_k, row + diag_shift);
        #pragma unroll
        for (int tile = 0; tile < kTilesN; ++tile) {
            #pragma unroll
            for (int frag_n = 0; frag_n < kFragsN; ++frag_n) {
                const int first_col = lane_col + frag_n * 4 + tile * WARP_N;
                #pragma unroll
                for (int elem = 0; elem < 4; ++elem) {
                    if (first_col + elem > last_kept_col) {
                        s_reg[tile][frag_m][frag_n].f32[elem] = -INFINITY;
                    }
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Sliding-window (local) mask: overwrites with -INFINITY every score that lies
// outside the attention window of its row.
//
// Index mapping used by this function:
//   row = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3) + m_frag * 4
//   col = wave_col_offset + (lane_id >> 4) * 8 + n_frag * 4 + tile * WARP_N + elem
//
// With shift = actual_seqlen_k - actual_seqlen_q, a column is kept when
//   col <= min(actual_seqlen_k - 1, row + shift + window_size_right)
// and, if window_size_left >= 0,
//   col >= max(0, row + 1 + shift - window_size_left) - 1.
//
// Fix: the right limit is inclusive (tested with '>'), so it is clamped to the
// last valid key index, actual_seqlen_k - 1. It was previously clamped to
// actual_seqlen_k, which left column actual_seqlen_k (one past the end)
// unmasked whenever the window reached the end of the sequence.
//
// Parameters:
//   s_reg             - per-thread score fragments, modified in place.
//   actual_seqlen_q   - query length used for the diagonal shift.
//   actual_seqlen_k   - key length; columns >= actual_seqlen_k are always masked.
//   wave_row_offset   - row offset added to every row index.
//   wave_col_offset   - column offset added to every column index.
//   window_size_left  - left window extent; a negative value disables the left limit.
//   window_size_right - right window extent, applied unconditionally.
//   lane_id           - lane index within the wave.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_apply_local_mask(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    int actual_seqlen_q,
    int actual_seqlen_k,
    int wave_row_offset,
    int wave_col_offset,
    int window_size_left,
    int window_size_right,
    int lane_id
) {
    __builtin_amdgcn_sched_barrier(0);
    const int row_base = wave_row_offset + ((lane_id & 15) >> 2) * 8 + (lane_id & 3);
    const int col_base = wave_col_offset + (lane_id >> 4) * 8;
    const bool has_ws_left = window_size_left >= 0;
    // Offset of the diagonal when the K and Q lengths differ.
    const int seq_shift = actual_seqlen_k - actual_seqlen_q;
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        const int row_idx = row_base + m_idx * 4;
        // Columns below (col_limit_left - 1) are masked when the left window is enabled.
        const int col_limit_left = max(0, row_idx + 1 + seq_shift - window_size_left);
        // Inclusive index of the last visible column, never past the last valid key.
        const int col_limit_right = min(actual_seqlen_k - 1, row_idx + seq_shift + window_size_right);
        #pragma unroll
        for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
            const int k_offset = k_loop * WARP_N;
            #pragma unroll
            for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
                const int n_base = col_base + n_idx * 4;
                #pragma unroll
                for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                    const int col_idx = n_base + k_offset + vec_idx;
                    const bool beyond_right = col_idx > col_limit_right;
                    const bool beyond_left = has_ws_left && (col_idx < col_limit_left - 1);
                    s_reg[k_loop][m_idx][n_idx].f32[vec_idx] =
                        (beyond_right || beyond_left)
                            ? -INFINITY
                            : s_reg[k_loop][m_idx][n_idx].f32[vec_idx];
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Multiplies every score fragment by the packed descale factor.
//
// Each vec4 fragment is processed as two 64-bit halves (u64[0] and u64[1]);
// each half is multiplied by qk_descale through __builtin_hcu_pk_mul_f32.
//
// Parameters:
//   s_reg      - per-thread score fragments, scaled in place.
//   qk_descale - packed pair of factors applied to each 64-bit half.
template<int kBlockN, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_qk_descale(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    __float2 qk_descale
) {
    constexpr int kTilesN = kBlockN / WARP_N;
    constexpr int kFragsM = WARP_M / 16;
    constexpr int kFragsN = WARP_N / 16;
    constexpr int kHalves = 2;  // a vec4 fragment holds two packed pairs

    __builtin_amdgcn_sched_barrier(0);
    #pragma unroll
    for (int frag_m = 0; frag_m < kFragsM; ++frag_m) {
        #pragma unroll
        for (int tile = 0; tile < kTilesN; ++tile) {
            #pragma unroll
            for (int frag_n = 0; frag_n < kFragsN; ++frag_n) {
                #pragma unroll
                for (int half = 0; half < kHalves; ++half) {
                    const auto scaled = __builtin_hcu_pk_mul_f32(
                        s_reg[tile][frag_m][frag_n].u64[half], qk_descale);
                    s_reg[tile][frag_m][frag_n].u64[half] = scaled;
                }
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// One online-softmax step over the score fragments, interleaved with loading V
// fragments from LDS into registers.
//
// Stages, in order:
//   1. Max:     per M fragment, running max over the new scores (seeded with scores_max),
//               combined across lanes with xor-shuffles of 32 and 16.
//   2. Softmax: s = exp2(s * softmax_scale_log2 - max * softmax_scale_log2), in place.
//   3. Sum:     per M fragment, sum of the exponentiated scores, combined with the same shuffles.
//   4. V load:  s_waitcnt vmcnt(0) + s_barrier, then matrix reads from v_lds into v_regs.
//   5. Rescale: scores_sum and acc_o are scaled by exp2((old_max - new_max) * softmax_scale_log2)
//               when the running max increased and the running sum is non-zero.
//   6. Update:  scores_max = new max; scores_sum += new sum.
//
// Template parameters:
//   AssumeValidRows - when true, the -INFINITY guards in stages 2 and 5 are compiled out.
//
// Parameters:
//   s_reg              - score fragments; overwritten with the exponentiated values.
//   scores_max         - running row max per M fragment (in/out).
//   scores_sum         - running row sum per M fragment (in/out).
//   acc_o              - output accumulator, rescaled in place.
//   softmax_scale_log2 - scale applied before exp2.
//   v_regs             - destination for the V fragments read from LDS.
//   v_lds              - LDS base of the V tile, addressed in bytes.
//
// Contains s_barrier, so every wave of the block must execute this function.
template<bool AssumeValidRows, int kHeadDim, int kBlockN, int WARP_M, int WARP_N, int WARP_K, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_softmax_and_schedule_v(
/*softmax module related args*/
vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16],
ElementAccum softmax_scale_log2,
/*scheduled modules related args*/
union_vec16_fp8 v_regs[kBlockN / WARP_N][kHeadDim / 32],
int8_t* v_lds
) {
// ======================================================== Max ======================================================================
ElementAccum scores_max_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// Seed with the running max so the result never drops below it.
ElementAccum max_value = scores_max[m_idx];
// The current thread walks the f32x4 outputs of the 4 32x32x32 mmac results.
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
max_value = max(max_value, s_reg[k_loop][m_idx][n_idx].f32[vec_idx]);
}
}
}
// This compares the data held by lanes 0, 16, 32 and 48 (lanes sharing the low 4 lane bits).
max_value = max(max_value, __shfl_xor_tmp(max_value, 32));
max_value = max(max_value, __shfl_xor_tmp(max_value, 16));
// Store as the final max for this fragment.
scores_max_cur[m_idx] = max_value;
}
// ========================================== softmax ===============================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// Additive term of the fused multiply-add: -max * scale, duplicated into both halves.
__float2 max_scaled_pair;
if constexpr (AssumeValidRows) {
max_scaled_pair[0] = -scores_max_cur[m_idx] * softmax_scale_log2;
} else {
// A row max of -INFINITY uses 0 as the additive term instead of -(-INF) * scale.
max_scaled_pair[0] = scores_max_cur[m_idx] == -INFINITY ? 0.f: -scores_max_cur[m_idx] * softmax_scale_log2;
}
max_scaled_pair[1] = max_scaled_pair[0];
__float2 softmax_scale_log2_pair = {softmax_scale_log2, softmax_scale_log2};
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
// Packed FMA on each 64-bit half: s * scale + (-max * scale).
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
s_reg[k_loop][m_idx][n_idx].u64[vec_idx] = __builtin_hcu_pk_fma_f32(s_reg[k_loop][m_idx][n_idx].u64[vec_idx], softmax_scale_log2_pair, max_scaled_pair);
}
// Element-wise exp2 of the shifted scores.
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
s_reg[k_loop][m_idx][n_idx].f32[vec_idx] = __llvm_exp2_f32(s_reg[k_loop][m_idx][n_idx].f32[vec_idx]);
}
}
}
}
// ========================================== Sum ===============================================
ElementAccum scores_sum_cur[WARP_M / 16];
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
vec2_Accum<ElementAccum> sum_pair;
// NOTE(review): 0.0 is a double literal; 0.0f may be intended — confirm against vec2_Accum::data's type.
sum_pair.data = 0.0;
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; ++k_loop) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
// Accumulate both 64-bit halves of the fragment with packed adds.
sum_pair.u64 = __builtin_hcu_pk_add_f32(sum_pair.u64, s_reg[k_loop][m_idx][n_idx].u64[0]);
sum_pair.u64 = __builtin_hcu_pk_add_f32(sum_pair.u64, s_reg[k_loop][m_idx][n_idx].u64[1]);
}
}
// Fold the packed pair, then combine across lanes with the same shuffles as the max.
scores_sum_cur[m_idx] = sum_pair.f32[0] + sum_pair.f32[1];
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor_tmp(scores_sum_cur[m_idx], 32);
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor_tmp(scores_sum_cur[m_idx], 16);
}
// Update scores_sum, scores_max.
// This code was placed here because the many ds instructions about to be issued compete with __shfl_xor for bandwidth, which makes the latency too high.
// ElementAccum exp_rescale[WARP_M / 16];
// #pragma unroll
// for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// exp_rescale[m_idx] = __llvm_exp2_f32((scores_max[m_idx] - scores_max_cur[m_idx]) * softmax_scale_log2);
// scores_max[m_idx] = scores_max_cur[m_idx];
// scores_sum[m_idx] = __llvm_fma_f32(scores_sum[m_idx], exp_rescale[m_idx], scores_sum_cur[m_idx]);
// }
// ========================================== schedule V ===============================================
__builtin_amdgcn_sched_barrier(0);
// Wait for outstanding vector-memory operations, then synchronize the waves of the block.
asm volatile("s_waitcnt vmcnt(0)\n\ts_barrier\n");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_loop = 0; k_loop < kBlockN / WARP_N; k_loop += 1) {
// Use ds_read_matrix to load data from LDS into registers.
// NOTE(review): the per-k_loop stride uses WARP_M, not WARP_N — confirm (equivalent only if WARP_M == WARP_N).
int8_t* lds_load_ptr = v_lds + k_loop * WARP_M * kHeadDim * sizeof(Element);
// Indices 0..3 are written unconditionally, so kHeadDim / 32 must be at least 4.
v_regs[k_loop][0].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr, 0, 2, 2, 0);
v_regs[k_loop][1].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 32, 0, 2, 2, 0);
v_regs[k_loop][2].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 128 * 16, 0, 2, 2, 0);
v_regs[k_loop][3].i32x4 = __builtin_hcu_ds_read_matrix_format_u8(lds_load_ptr + 128 * 16 + 32, 0, 2, 2, 0);
}
__builtin_amdgcn_sched_barrier(0); // hint: consider issuing only part of the ds_read_matrix instructions here, to avoid stalling
// ========================================== rescale ===============================================
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; m_idx += 1) {
// Rescale only when there is accumulated state and the running max actually increased.
if (scores_sum[m_idx] != 0.f && scores_max[m_idx] < scores_max_cur[m_idx]) {
__float2 scores_scale_pair;
float max_diff;
if constexpr (AssumeValidRows) {
max_diff = scores_max[m_idx] - scores_max_cur[m_idx];
} else {
// Fix: when scores_max and scores_max_cur are both -INFINITY, (-INF) - (-INF) = NaN.
// This happens when a query row has no valid KV to attend to.
max_diff = (scores_max[m_idx] == -INFINITY || scores_max_cur[m_idx] == -INFINITY)
? 0.f : (scores_max[m_idx] - scores_max_cur[m_idx]);
}
scores_scale_pair[0] = __llvm_exp2_f32(max_diff * softmax_scale_log2);
scores_scale_pair[1] = scores_scale_pair[0];
scores_sum[m_idx] *= scores_scale_pair[0];
// Rescale acc_o.
// NOTE(review): the loop bounds (kHeadDim / WARP_N, WARP_K / 16) differ from acc_o's declared
// extents (kHeadDim / 32, WARP_N / 16); they agree only if WARP_N == 32 and WARP_K == WARP_N — confirm.
#pragma unroll
for (int pv_loop = 0; pv_loop < kHeadDim / WARP_N; ++pv_loop) {
#pragma unroll
for (int mmac_id = 0; mmac_id < WARP_K / 16; ++mmac_id) {
acc_o[pv_loop][m_idx][mmac_id].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[pv_loop][m_idx][mmac_id].u64[0], scores_scale_pair);
acc_o[pv_loop][m_idx][mmac_id].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[pv_loop][m_idx][mmac_id].u64[1], scores_scale_pair);
}
}
}
}
// Commit the new running max and add this step's sum to the running sum.
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_max[m_idx] = scores_max_cur[m_idx];
scores_sum[m_idx] += scores_sum_cur[m_idx];
}
}
// Converts the f32 score fragments to packed fp8 and stores them in p_reg.
//
// Each fragment's four f32 values are converted by
// __builtin_hcu_cvt_pk4_fp8_f32<Element> into one 32-bit slot of
// p_reg[m].i32, at slot index tile * 2 + n_frag.
//
// Parameters:
//   s_reg - source f32 fragments.
//   p_reg - destination packed fp8 registers, one entry per M fragment.
template<int kBlockN, int WARP_M, int WARP_N, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_cvt_f32_to_fp8(
    vec4_Accum<ElementAccum> s_reg[kBlockN / WARP_N][WARP_M / 16][WARP_N / 16],
    union_vec32_fp8 p_reg[WARP_M / 16]
) {
    constexpr int kTilesN = kBlockN / WARP_N;
    constexpr int kFragsM = WARP_M / 16;
    constexpr int kFragsN = WARP_N / 16;

    #pragma unroll
    for (int frag_m = 0; frag_m < kFragsM; ++frag_m) {
        #pragma unroll
        for (int tile = 0; tile < kTilesN; ++tile) {
            // Slot of this tile's first fragment; the stride of 2 is fixed.
            const int tile_slot = tile * 2;
            #pragma unroll
            for (int frag_n = 0; frag_n < kFragsN; ++frag_n) {
                __builtin_hcu_cvt_pk4_fp8_f32<Element>(
                    s_reg[tile][frag_m][frag_n].f32,
                    p_reg[frag_m].i32[tile_slot + frag_n]);
            }
        }
    }
}
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
template<int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, int TailTile16, bool Is_even_MN, bool Is_Interleaved, bool TcpSwizzle, typename Element, typename ElementAccum> template<int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, bool Is_even_MN, bool Is_Interleaved, bool TcpSwizzle, typename Element, typename ElementAccum>
__forceinline__ __device__ void fwd_epilogue_store_output_gfx938( __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
Element *o_ptr, Element *o_ptr,
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4], vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4],
...@@ -16,9 +16,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938( ...@@ -16,9 +16,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
int pv_lane_seq_idx = lane_id & 15; int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4; int pv_lane_head_dim_idx = lane_id >> 4;
static_assert (Is_Interleaved and "For fwd_epilogue_store_output_gfx938, mmac must be 4interleave"); if constexpr (Is_Interleaved) {
if constexpr (TailTile16 == 2) {
#pragma unroll #pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) { for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll #pragma unroll
...@@ -42,44 +40,14 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938( ...@@ -42,44 +40,14 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
} }
} else { } else {
*(vec4_fp32*)(o_ptr + v_offset + s_offset) = v_data.f32; *(vec4_fp32*)(o_ptr + v_offset + s_offset) = v_data.f32;
// auto o_resource = prepare_for_buffer_load<kHeadDimV, Element, true>(o_ptr);
// inline_buffer_store_dwordx4<vec4_fp32, 1>(v_data.f32, v_offset, o_resource, s_offset, /* immediate integer */0);
} }
} }
} }
} }
} } // brace, to control vgpr usage
} else { } else {
#pragma unroll
for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
for (int warp_m_idx = 0; warp_m_idx < (WARP_M / 32); ++warp_m_idx) {
#pragma unroll
for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
#pragma unroll 2
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int pv_tile_id = k_loop * (WARP_M / 32) * (kBlockK / 32) + warp_m_idx * (kBlockK / 32) + k_tile_idx;
const int mmac_id = min_tile_m + min_tile_n * 2;
int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * 32 + min_tile_m * 16 + pv_lane_seq_idx;
int s_offset = k_tile_idx * 32 + min_tile_n * 16;
int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK + pv_lane_head_dim_idx * 4;
union_vec2_f16x2<Element> v_data;
#pragma unroll
for (int vec_index = 0; vec_index < 2; ++vec_index) {
v_data.f16x2[vec_index] = DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
}
if constexpr (not Is_even_MN) {
if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
} else {
*(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
}
}
}
}
}
}
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
\ No newline at end of file
...@@ -48,21 +48,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -48,21 +48,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
// MLS // MLS
if constexpr (Is_even_MN) { // Is_even_MN 简单场景下 nm_filter 场景简化, 非 BlockM = 128 场景未必全支持 if constexpr (Is_even_MN) { // Is_even_MN 简单场景下 nm_filter 场景简化, 非 BlockM = 128 场景未必全支持
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES);
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] = 0x20000;
} else { } else {
int nm_filter_max = n_loop * WARP_K + 32 - max_seq_kv_offset; int nm_filter_max = n_loop * WARP_K + 32 - max_seq_kv_offset;
int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop; // 如果全越界了, 则只访问 n_loop = 0 的那波数据 int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
v_srsrc[3] = nm_filter << 8; v_srsrc[3] = nm_filter << 8;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32) * ELEMENT_BYTES);
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] += 0x20000;
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} else if constexpr (kHeadDimV == 192) { } else if constexpr (kHeadDimV == 192) {
int warp_id_m = warp_id % 2; // w0 w2 int warp_id_m = warp_id % 2; // w0 w2
int warp_id_n = warp_id / 2; // w1 w3 int warp_id_n = warp_id / 2; // w1 w3
...@@ -72,14 +69,11 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -72,14 +69,11 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
if constexpr (true) { if constexpr (true) {
int nm_filter = inline_min_max<0, 16>(n_loop_ * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset); // 重新计算 nm_filter int nm_filter = inline_min_max<0, 16>(n_loop_ * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset); // 重新计算 nm_filter
v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8; v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] += 0x20000;
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x16_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 16, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
// DS // DS
lds_stage_id ^= 1; lds_stage_id ^= 1;
...@@ -87,11 +81,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -87,11 +81,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS); flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES; int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
if constexpr (TailTile16 == 2) {
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/); DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
} else {
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
}
stage_id ^= 1; stage_id ^= 1;
for (int k_loop = 1; k_loop < (computeHeadDim / kBlockN); ++k_loop) { for (int k_loop = 1; k_loop < (computeHeadDim / kBlockN); ++k_loop) {
...@@ -104,11 +94,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -104,11 +94,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
} }
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES; int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
if constexpr (TailTile16 == 2) {
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/); DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
} else {
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
}
flash::wait_lds_data_arrived<false>(3); flash::wait_lds_data_arrived<false>(3);
// MMAC // MMAC
...@@ -154,21 +140,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -154,21 +140,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
int n_loop_ = ((kHeadDimV / kBlockN) - k_loop) < 4 ? (n_load = 0, n_loop): n_loop_; // if finish kHeadDimV*WarpK prefetch, we prefetch next n_loop data int n_loop_ = ((kHeadDimV / kBlockN) - k_loop) < 4 ? (n_load = 0, n_loop): n_loop_; // if finish kHeadDimV*WarpK prefetch, we prefetch next n_loop data
if constexpr (Is_even_MN) { if constexpr (Is_even_MN) {
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * 32 * WARP_NUM) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * 32 * WARP_NUM) * ELEMENT_BYTES);
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] = 0x20000;
} else { } else {
int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset; int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据 int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
v_srsrc[3] = nm_filter << 8; v_srsrc[3] = nm_filter << 8;
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] += 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * 32 * WARP_NUM) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * 32 * WARP_NUM) * ELEMENT_BYTES);
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
} }
} }
...@@ -230,21 +213,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -230,21 +213,18 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
if constexpr (Is_even_MN) { if constexpr (Is_even_MN) {
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id_m * 16 * seqlen_v_stride + warp_id_n * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id_m * 16 * seqlen_v_stride + warp_id_n * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] = 0x20000;
} else { } else {
int nm_filter_max = n_loop_ * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset; int nm_filter_max = n_loop_ * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset;
int real_mls_loop = nm_filter_max >= 16 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据 int real_mls_loop = nm_filter_max >= 16 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
int nm_filter = inline_min_max<0, 16>(real_mls_loop * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset); // 重新计算 nm_filter int nm_filter = inline_min_max<0, 16>(real_mls_loop * WARP_K + warp_id_m * 16 + 16 - max_seq_kv_offset); // 重新计算 nm_filter
v_srsrc[3] = nm_filter << 8; v_srsrc[3] = nm_filter << 8;
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] += 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id_m * 16 * seqlen_v_stride + warp_id_n * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id_m * 16 * seqlen_v_stride + warp_id_n * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x16_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 16, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
lds_stage_id ^= 1; lds_stage_id ^= 1;
...@@ -263,11 +243,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -263,11 +243,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
// DS // DS
int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES; int lds_load_offset = v_lds_base + (0/*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
if constexpr (TailTile16 == 2) {
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/); DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
} else {
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
}
stage_id ^= 1; stage_id ^= 1;
for (int k_loop = 1; k_loop < (computeHeadDim / kBlockN); ++k_loop) { for (int k_loop = 1; k_loop < (computeHeadDim / kBlockN); ++k_loop) {
...@@ -281,11 +257,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds( ...@@ -281,11 +257,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
// DS // DS
int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES; int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
if constexpr (TailTile16 == 2) {
DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/); DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
} else {
DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false/*transpose*/);
}
flash::wait_lds_data_arrived<false>(3); flash::wait_lds_data_arrived<false>(3);
// MMAC // MMAC
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "intrinsic_mls_ds.h" #include "intrinsic_mls_ds.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int TailTile16, typename Element, bool Is_even_MN> template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds( __forceinline__ __device__ void prefetch_v_to_lds_mls_ds(
vec4_uint v_ptr, vec4_uint v_ptr,
Element* v_lds, Element* v_lds,
...@@ -28,15 +28,13 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds( ...@@ -28,15 +28,13 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds(
if constexpr (true) { if constexpr (true) {
int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset); int nm_filter = inline_min_max<0, 32>(0 * WARP_K + 32 - max_seq_kv_offset);
v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8; v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
if constexpr (TailTile16 == 2) { v_srsrc[3] += 0x20000; } v_srsrc[3] += 0x20000;
} }
int lds_stage_id = 0; int lds_stage_id = 0;
int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 防止写 v lds 和读 k lds 冲突, qk 可能有的 warp 没结束 flash::wait_all_warp_arrived(); // 防止写 v lds 和读 k lds 冲突, qk 可能有的 warp 没结束
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
...@@ -79,10 +79,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds( ...@@ -79,10 +79,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
} }
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} else if constexpr (kHeadDim == 192) { } else if constexpr (kHeadDim == 192) {
int warp_id_m = warp_id / 2; int warp_id_m = warp_id / 2;
int warp_id_n = warp_id % 2; int warp_id_n = warp_id % 2;
...@@ -95,10 +92,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds( ...@@ -95,10 +92,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
} }
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
// Wait MLS // Wait MLS
...@@ -178,10 +172,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds( ...@@ -178,10 +172,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
} }
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
} }
} }
...@@ -242,10 +233,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds( ...@@ -242,10 +233,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
} }
int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (n_stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
// Wait MLS // Wait MLS
...@@ -362,10 +350,11 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds( ...@@ -362,10 +350,11 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
if constexpr (STAGES == 2) { if constexpr (STAGES == 2) {
#if defined(__gfx938__) // 有的 prefetch v 写到了 mha 主 kernel 代码里 #if defined(__gfx938__) // 有的 prefetch v 写到了 mha 主 kernel 代码里
prefetch_v_to_lds_mls_ds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, TailTile16, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset); prefetch_v_to_lds_mls_ds<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, max_seq_k_offset);
#else #else
#endif #endif
} }
} // qk_gemm } // qk_gemm
...@@ -35,10 +35,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds( ...@@ -35,10 +35,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once
} }
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
} }
stage_id ^= 1; stage_id ^= 1;
...@@ -50,10 +47,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds( ...@@ -50,10 +47,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; // set only once
} }
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
stage_id ^= 1; stage_id ^= 1;
// DS // DS
...@@ -114,9 +108,7 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds( ...@@ -114,9 +108,7 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds(
} }
int lds_offset = (stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (stage_id * WARP_N * kHeadDim_ + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
...@@ -174,7 +174,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) * ...@@ -174,7 +174,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll #pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) { for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32( acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx], acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair scores_scale_pair
); );
...@@ -201,7 +201,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) * ...@@ -201,7 +201,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
for (int mi = 0; mi < (WARP_M / 32); ++mi) { for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32( scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64, scores_sum[mi].u64,
scores_sum_cur[mi].u64 scores_sum_cur[mi].u64
); );
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment