Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
......@@ -236,7 +236,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -270,7 +270,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -374,13 +374,12 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N
int qk_tile_id = mi + ni * (WARP_M / 32);
#if defined(__gfx936__) || defined(__gfx938__)
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[qk_tile_id][mmac_id].u64[vec_idx] = hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
}
......@@ -397,7 +396,7 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N
template<bool Is_first, bool Check_inf, typename DataType0, typename DataType1, int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N, bool IsInference=true>
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1, int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N>
inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_M / 32)][4], DataType1 *scores_max, DataType1 *scores_sum,
DataType0 acc_o[(K / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4], float softmax_scale_log2) {
if constexpr (Is_first) {
......@@ -418,7 +417,6 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
? scores_max_cur[mi * 2].f32[min_tile_m]
: (scores_max_cur[mi * 2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi * 2].f32[min_tile_m]);
if (IsInference or scores_max[mi * 2].f32[min_tile_m] < scores_max_cur_reg) {
float scores_scale = __llvm_exp2_f32((scores_max[mi * 2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi * 2].f32[min_tile_m] *= scores_scale;
......@@ -428,13 +426,17 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
for(int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK / 32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + mi + ni * (WARP_M / 32);
int mmac_id = min_tile_n * 2 + min_tile_m;
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair
);
......@@ -450,7 +452,6 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
}
}
}
}
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max_cur, softmax_scale_log2);
DataType1 scores_sum_cur[(WARP_M / 32)];
......@@ -461,7 +462,7 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......@@ -486,7 +487,7 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
template <int WARP_M, int WARP_N, typename Element, typename ElementAccum, bool IsInference=false>
template <int WARP_M, int WARP_N, typename Element, typename ElementAccum>
inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (WARP_N / 32)][4], union_vec4_fp32 s_reg[(WARP_M / 32) * (WARP_N / 32)][4]) {
#pragma unroll
for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
......@@ -497,25 +498,10 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#if defined(__gfx938__)
if constexpr (IsInference) {
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPairNoPack<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0],
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]
);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPairNoPack<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0],
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]
);
} else {
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]);
}
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]);
p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair<float, Element>(
s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32x2[min_tile_k]);
#else
p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast<float, Element, false>(
s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]);
......
......@@ -112,10 +112,10 @@ struct Allreduce {
#if defined(__gfx936__) || defined(__gfx938__)
res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
x.u64 = hcu_pk_add_f32(x.u64, res.u64);
x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
res.f32[0] = __shfl_swap16(x.f32[0]); // __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = __shfl_swap16(x.f32[1]); // __shfl_xor_tmp(x.f32[1], 16);
res.u64 = hcu_pk_add_f32(res.u64, x.u64);
res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
......@@ -141,10 +141,7 @@ struct Allreduce {
template<const int kHeadDim, typename T, bool Do_CacheSwizzle=true>
__device__ __forceinline__ vec4_uint prepare_for_buffer_load(T* ptr) {
vec4_uint res;
struct { uint32_t lo, hi; } parts;
*(uint64_t*)&parts = reinterpret_cast<uint64_t>(ptr);
res[0] = __builtin_amdgcn_readfirstlane(parts.lo);
res[1] = __builtin_amdgcn_readfirstlane(parts.hi);
*(uint64_t*)&res = reinterpret_cast<uint64_t>(ptr);
if constexpr (Do_CacheSwizzle) {
if constexpr (kHeadDim == 128) {
res[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
......@@ -213,4 +210,28 @@ __forceinline__ __device__ void attention_initialize(
}
template<int kHeadDim, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_attention_initialize(
ElementAccum scores_max[WARP_M / 16],
ElementAccum scores_sum[WARP_M / 16],
vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16]
) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
scores_max[m_idx] = -INFINITY;
scores_sum[m_idx] = 0;
}
#pragma unroll
for (int pv_loop = 0; pv_loop < kHeadDim / 32; ++pv_loop) {
#pragma unroll
for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WARP_N / 16; ++n_idx) {
inline_vgpr4_init_zero(acc_o[pv_loop][m_idx][n_idx]);
}
}
}
}
} // namespace flash
......@@ -9,22 +9,7 @@
// #define USE_BUFFER_LOAD_DWORDX2
#endif
// DTK ds_read_matrix builtins (DS_READ_MATRIX_FORMAT / _TRANS_FORMAT): para1 is LDS base
// typed per element kind — e.g. *_f16 → half*3, *_bf16 / *_u16 / *_i16 → short*3, *_f32 → float*3,
// 4/8-bit and tf32/u32/i32 variants → int*3 (vendor builtin table).
// HIP may use __half for fp16 LDS while builtins expect __fp16*3; use f16 helper below.
// Probe: FA_PROBE_FAMILY_DS (lds_f16_as3, lds_bf16_as3).
template<typename T>
__forceinline__ __device__ __attribute__((address_space(3))) __fp16 *
hcu_ds_read_matrix_f16_lds_base(T *const p) {
return (__attribute__((address_space(3))) __fp16 *)(p);
}
template<typename T>
__forceinline__ __device__ __attribute__((address_space(3))) short *
hcu_ds_read_matrix_bf16_lds_base(T *const p) {
return (__attribute__((address_space(3))) short *)(p);
}
template<class DataType>
__forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resource) {
......@@ -44,177 +29,99 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx2_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void safe_inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &offset_s, const int &offset_v) {
int lds_addr_per_wave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int __offset_s = offset_s << shfl_count;
int __offset_v = offset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_nop 3\n\t"
"s_mov_b32 m0, %1\n\t"
"s_nop 0\n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
:: "v"(__offset_v), "s"(lds_addr_per_wave), "s"(scalar_rsrc), "s"(__offset_s)
:: "v"(__offset_v), "s"(lds_addr_per_wave), "s"(global_addr), "s"(__offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_glc_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc slc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l1_glc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l2_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane(
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 slc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s)
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
template<typename src_type=half_t, typename dst_type=float, const int dword_count=1, const int auxilariy=0>
__forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
#if defined(__gfx936__) || defined(__gfx938__)
static_assert(dword_count == 1 || dword_count == 2 || dword_count == 4, "unsupported buffer_load_dword LDS width");
// DTK currently accepts the mature asm buffer_load_* -> lds shape more reliably than
// the raw_buffer_load_lds wrapper instantiated through generic LDS pointers.
if constexpr (auxilariy == 0) {
if constexpr (dword_count == 1) {
inline_buffer_load_dword_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else if constexpr (dword_count == 2) {
inline_buffer_load_dwordx2_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else {
inline_buffer_load_dwordx4_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
}
} else if constexpr (auxilariy == 11 && dword_count == 1) {
inline_buffer_load_dword_lds_bypass_glc_slc<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else {
constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element);
__builtin_hcu_raw_buffer_load_lds(
rsrc,
ptr,
dword_count * 4,
gvOffset_v * bytes_per_element,
gvOffset_s * bytes_per_element,
0, /* immediate offset, instruction offset */
auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element);
dst_type *ptr = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
__builtin_hcu_raw_buffer_load_lds(
rsrc,
ptr,
......@@ -224,16 +131,12 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const sh
0, /* immediate offset, instruction offset */
auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#endif
}
template<typename src_type=half_t, typename dst_type=float>
__forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
#if defined(__gfx936__) || defined(__gfx938__)
inline_buffer_load_dword_lds_bypass_glc_slc<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
#else
constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element);
dst_type *ptr = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
__builtin_hcu_raw_buffer_load_lds(
rsrc,
ptr,
......@@ -243,7 +146,6 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src
0, /* immediate offset, instruction offset */
11 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#endif
}
template<class DataType, const int shfl_count>
......@@ -570,43 +472,6 @@ inline __HOST_DEVICE__ unsigned short inlineasm_float2bfloat16_ushort_nonan(cons
// DTK-compatible pk helpers (replace __builtin_hcu_pk_*_f32)
inline __device__ __float2 hcu_pk_add_f32(__float2 a, __float2 b) {
__float2 o;
asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(o) : "v"(a), "v"(b));
return o;
}
inline __device__ __float2 hcu_pk_mul_f32(__float2 a, __float2 b) {
__float2 o;
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(o) : "v"(a), "v"(b));
return o;
}
inline __device__ __float2 hcu_pk_fma_f32(__float2 x, __float2 m, __float2 a) {
__float2 d;
asm volatile("v_pk_fma_f32 %0, %1, %2, %3" : "=v"(d) : "v"(x), "v"(m), "v"(a));
return d;
}
// DTK requires these control operands to remain compile-time constants.
template<bool Clamp = false, int OModifier = 0>
inline __device__ auto hcu_cvt_pk_f16_f32(float src0, float src1) {
static_assert(OModifier == 0, "Only o_modifier=0 is currently validated in HG DTK migration");
return __builtin_hcu_cvt_pk_f16_f32(0, src0, 0, src1, Clamp, OModifier);
}
template<bool Clamp = false>
inline __device__ auto hcu_cvt_pk_bf16_f32(float src0, float src1) {
return __builtin_hcu_cvt_pk_bf16_f32(0, src0, 0, src1, Clamp);
}
template<int ByteSel>
inline __device__ vec2_fp32 hcu_cvt_pk_f32_fp8(int src0) {
static_assert(ByteSel == 0 || ByteSel == 2, "ByteSel must select the low or high packed fp8 pair");
return __builtin_hcu_cvt_pk_f32_fp8(src0, false, 0, ByteSel);
}
// d = a * b + c
inline __device__ __float2 inlineasm_fa_v_pk_fma_f32(__float2 &a, const __float2& b, const __float2& c) {
__float2 d;
......@@ -874,7 +739,7 @@ inline __host__ __device__ auto DownCastPair(const vec2_Element<FromType>& sourc
template<>
inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__)
auto result = hcu_cvt_pk_f16_f32<false, 0>(source[0], source[1]);
auto result = __builtin_hcu_cvt_pk_f16_f32(source[0], source[1], false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result);
#else
return __builtin_amdgcn_cvt_pkrtz(source[0], source[1]);
......@@ -884,7 +749,7 @@ inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<f
template<>
inline __host__ __device__ auto DownCastPair<float, bhalf_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__)
auto result = hcu_cvt_pk_bf16_f32<false>(source[0], source[1]);
auto result = __builtin_hcu_cvt_pk_bf16_f32(source[0], source[1], false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result);
#else
vec2_Element<bhalf_t> result;
......@@ -904,7 +769,7 @@ inline __host__ __device__ auto DownCastPairNoPack(const FromType src0, const Fr
template<>
inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float src0, const float src1) {
#if defined(__gfx938__)
auto result = hcu_cvt_pk_f16_f32<false, 0>(src0, src1);
auto result = __builtin_hcu_cvt_pk_f16_f32(src0, src1, false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result);
#else
return __builtin_amdgcn_cvt_pkrtz(src0, src1);
......@@ -914,7 +779,7 @@ inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float sr
template<>
inline __host__ __device__ auto DownCastPairNoPack<float, bhalf_t>(const float src0, const float src1) {
#if defined(__gfx938__)
auto result = hcu_cvt_pk_bf16_f32<false>(src0, src1);
auto result = __builtin_hcu_cvt_pk_bf16_f32(src0, src1, false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result);
#else
vec2_Element<bhalf_t> result;
......
......@@ -7,35 +7,126 @@
#define VA_LIMIT_BITS(x) (0xffffffffffff & x)
template<int INSTM, int INSTNM, int T, int R>
__forceinline__ __device__ void matrix_load_b16_lds_trans_builtin(size_t lds_addr_warp, vec4_int rsrc, int moffset) {
#define MATRIX_LOAD_32X32_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
}
#endif
}
#define MATRIX_LOAD_32X32_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,,);
}
#endif
}
// ======================================================= MLS32x16 ===========================================================
#define MATRIX_LOAD_32X16_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int soffset = lds_addr_warp + 0x80000000;
if constexpr (INSTM == 32 && INSTNM == 16) {
__builtin_hcu_matrix_load_32x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
} else if constexpr (INSTM == 32 && INSTNM == 32) {
__builtin_hcu_matrix_load_32x32_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
} else if constexpr (INSTM == 64 && INSTNM == 16) {
__builtin_hcu_matrix_load_64x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
}
(void)moffset;
#endif
}
template<int INSTM, int INSTNM, int T, int R>
__forceinline__ __device__ void matrix_load_b16_lds_builtin(size_t lds_addr_warp, vec4_int rsrc, int moffset) {
#define MATRIX_LOAD_32X16_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int soffset = lds_addr_warp + 0x00000000;
if constexpr (INSTM == 32 && INSTNM == 16) {
__builtin_hcu_matrix_load_32x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
} else if constexpr (INSTM == 32 && INSTNM == 32) {
__builtin_hcu_matrix_load_32x32_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
} else if constexpr (INSTM == 64 && INSTNM == 16) {
__builtin_hcu_matrix_load_64x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0);
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,,);
}
(void)moffset;
#endif
}
......
#pragma once
#include "numeric_types.h"
// DTK: __builtin_hcu_matrix_load_*_b8 第二参为 addrspace(3) char*(本机 clang 报错 short* 与 char* 不匹配);b16 用 short* 见 intrinsic_mls_ds.h。
// 改法、soffset(trans +0x80000000)、调用方式与验证:见仓库根目录 ROCM指令迁移到DTK.md §4。
// Inline asm with "s"(vec4_uint) can lower srsrc to VGPR and fail with invalid operand; builtins keep srsrc in the correct class.
template<int r, int t>
__forceinline__ __device__ void matrix_load_128x16_b8_lds_trans_builtin(size_t lds_addr_warp, vec4_int rsrc, int /*matrix_offset*/) {
#if defined(__gfx938__)
int soffset = static_cast<int>(lds_addr_warp) + 0x80000000;
// Third arg must be compile-time constant (same pattern as matrix_load_b16); call sites use matrix_offset==0.
__builtin_hcu_matrix_load_128x16_b8(
rsrc,
(__attribute__((address_space(3))) char*)(soffset),
0,
t,
r,
0,
0);
#endif
}
#define MATRIX_LOAD_128X16_B8_LDS_TRANS(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 4\n\t" \
"matrix_load_128x16_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(MATRIX_OFFSET) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__)
union union_vec4_uint u;
u.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset);
matrix_load_128x16_b8_lds_trans_builtin<r, t>(lds_addr_warp, u.i32, matrix_offset);
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,, t);
} else {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,,);
}
#endif
}
......@@ -50,28 +43,25 @@ __forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType
}
template<int r, int t>
__forceinline__ __device__ void matrix_load_64x32_b8_lds_rearrange_builtin(size_t lds_addr_warp, vec4_int rsrc, int /*matrix_offset*/) {
#if defined(__gfx938__)
int soffset = static_cast<int>(lds_addr_warp);
__builtin_hcu_matrix_load_64x32_b8(
rsrc,
(__attribute__((address_space(3))) char*)(soffset),
0,
t,
r,
0,
0);
#endif
}
#define MATRIX_LOAD_64x32_B8_LDS_REARRANGE(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
asm volatile("s_nop 4\n\t" \
"matrix_load_64x32_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(MATRIX_OFFSET) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_64x32_b8_lds_rearrange(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__)
union union_vec4_uint u;
u.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset);
matrix_load_64x32_b8_lds_rearrange_builtin<r, t>(lds_addr_warp, u.i32, matrix_offset);
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset,, t);
} else {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset,,);
}
#endif
}
......@@ -120,3 +110,9 @@ inline __device__ vec4_fp32 mmac_4interleave_b8(const vec8_Element<T> &v1, const
{
return __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(v1, v2, v3, 1, 0);
}
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac_4interleave_fp8(const vec8_Element<T> &v1, const vec8_Element<T> &v2, const vec4_fp32 &v3)
{
return __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(v1, v2, v3, 1, 0);
}
File mode changed from 100644 to 100755
......@@ -20,8 +20,9 @@ __forceinline__ __device__ void kvcache_epilogue_store_output_gfx938(
: reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
// Specialized optimizatio for headdim 128
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1);
// Specialized optimization for headdim 128. Dim256 is split into two
// 128-column stores so it can use the same layout per split.
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and K_LOOP_COUNT == WARP_NUM);
if constexpr (not OPT_FOR_HDIM128) {
if (warp_id > 0) return;
}
......@@ -90,8 +91,9 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false/*USE_CACHE_SWIZZLE*/>(o_ptr);
int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4;
// Specialized optimizatio for headdim 128
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1);
// Specialized optimization for headdim 128. Dim256 is split into two
// 128-column stores so it can use the same layout per split.
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and K_LOOP_COUNT == WARP_NUM);
if constexpr (not OPT_FOR_HDIM128) {
if (warp_id > 0) return;
}
......@@ -124,3 +126,39 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention epilogue helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_epilogue_rescale_acco_gfx938(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
ElementAccum v_descale) {
#pragma unroll
for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
ElementAccum sum = scores_sum[mi].f32[min_tile_m];
ElementAccum inv_sum = (sum == 0.f || sum != sum) ? v_descale : v_descale / sum;
__float2 scale_pair = {inv_sum, inv_sum};
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#pragma unroll
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] =
__builtin_hcu_pk_mul_f32(acc_o[tile_32x32_id][mmac_id].u64[vec_id], scale_pair);
}
}
}
}
}
}
}
......@@ -66,10 +66,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0);
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + v_mls_lds_warp_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
......
......@@ -2,6 +2,7 @@
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
#include "intrinsic_mls_ds_b8.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
......@@ -47,10 +48,376 @@ __forceinline__ __device__ void kvcache_prefetch_v_to_lds_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0);
union union_vec4_uint v_rsrc_bits;
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + v_mls_lds_warp_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
__builtin_amdgcn_sched_barrier(0);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention PV helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_k_gfx938(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int k_row_stride,
int max_seq_k_offset);
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_gfx938(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int v_row_stride,
int max_seq_v_offset) {
static_assert(K_LOOP_COUNT % 2 == 0);
constexpr int PREFETCH = 2;
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
int stage_id = 0;
constexpr int k_loop = K_LOOP_COUNT - 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);
int warp_global_bytes;
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(Element);
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
v_srsrc[3] = (nm_filter << 8) + 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
}
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element, int load_id>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_one_gfx938(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int v_row_stride,
int max_seq_v_offset) {
static_assert(K_LOOP_COUNT == 2);
static_assert(load_id == 0 || load_id == 1);
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
constexpr int stage_id = 0;
constexpr int k_loop = K_LOOP_COUNT - 1;
const int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);
const int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(Element);
const int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
const int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
const int warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(Element);
const int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
v_srsrc[3] = (nm_filter << 8) + 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
template<bool PrefetchK, int K_LOOP_COUNT, int kBlockK, int kBlockN, int M_WARP_COUNT, int K_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT, typename V_Element, typename P_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint& k_addr,
V_Element* v_lds,
V_Element* k_lds,
union_vec2_f16x2<P_Element> p_reg[M_WARP_COUNT * K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_v_offset,
int64_t k_addr_offset) {
static_assert(K_LOOP_COUNT % 2 == 0);
constexpr int PREFETCH = 2;
flash::wait_lds_data_arrived<true/*sync*/>(0);
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
int stage_id = 1;
#pragma unroll
for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
int warp_global_bytes;
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
v_srsrc[3] = (nm_filter << 8) + 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
flash::wait_buffer_data_arrived<false/*sync*/>(PREFETCH);
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
vec2_fp32 v_f32x2[4];
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
flash::wait_buffer_data_arrived<false/*sync*/>(0);
constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;
{
constexpr int k_loop = 1 - PREFETCH;
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
if constexpr (PrefetchKInPV) {
if (load_id == 0 && tile32x32_id == 1) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
vec2_fp32 v_f32x2[4];
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
if constexpr (PrefetchK && !PrefetchKInPV) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
template<bool PrefetchK, int K_LOOP_COUNT, int kBlockK, int kBlockN, int M_WARP_COUNT, int K_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT, typename V_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint& k_addr,
V_Element* v_lds,
V_Element* k_lds,
union_vec32_fp8 p_reg[M_MMAC_COUNT],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_v_offset,
int64_t k_addr_offset) {
static_assert(K_LOOP_COUNT % 2 == 0);
static_assert(M_WARP_COUNT == 1);
static_assert(K_WARP_COUNT == 2);
constexpr int PREFETCH = 2;
flash::wait_lds_data_arrived<true/*sync*/>(0);
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
int stage_id = 1;
#pragma unroll
for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
int warp_global_bytes;
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
v_srsrc[3] = (nm_filter << 8) + 0x20000;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
flash::wait_buffer_data_arrived<false/*sync*/>(PREFETCH);
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave_fp8<int8_t, ElementAccum>(
p_reg[min_tile_m].i8x8[0],
v_regs[tile32x32_id].i8x8[min_tile_dim],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
flash::wait_buffer_data_arrived<false/*sync*/>(0);
constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;
{
constexpr int k_loop = 1 - PREFETCH;
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
if constexpr (PrefetchKInPV) {
if (load_id == 0 && tile32x32_id == 1) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave_fp8<int8_t, ElementAccum>(
p_reg[min_tile_m].i8x8[0],
v_regs[tile32x32_id].i8x8[min_tile_dim],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
if constexpr (PrefetchK && !PrefetchKInPV) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
template <int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void fp8_kvcache_cvt_f32_to_fp8_gfx938(
union_vec32_fp8 p_reg[M_MMAC_COUNT],
vec4_Accum<ElementAccum> s_reg[1][4]) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
__builtin_hcu_cvt_pk4_fp8_f32<Element>(s_reg[0][0 * 2 + min_tile_m].f32, p_reg[min_tile_m].i32[0]);
__builtin_hcu_cvt_pk4_fp8_f32<Element>(s_reg[0][1 * 2 + min_tile_m].f32, p_reg[min_tile_m].i32[1]);
}
}
......@@ -73,10 +73,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
}
// 等待 MLS 数据回来
......@@ -272,3 +269,4 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
}
} // qk_gemm
......@@ -3,6 +3,7 @@
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "kvcache_pv_gemm_utils_gfx938.h"
#include "intrinsic_mls_ds_b8.h"
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT>
......@@ -40,10 +41,7 @@ __forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_gfx938(
}
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
flash::wait_lds_data_arrived<true>(0);
union union_vec4_uint q_rsrc_bits;
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(q_lds, q_srsrc, lds_offset_bytes, 0);
flash::wait_buffer_data_arrived<true>(0);
......@@ -199,11 +197,175 @@ __forceinline__ __device__ void kvcache_prefetch_k_to_lds_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
union union_vec4_uint k_rsrc_bits;
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_k_gfx938(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int k_row_stride,
int max_seq_k_offset) {
int stage_id = 0;
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = k_row_stride;
constexpr int k_loop = 0;
int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
int warp_global_bytes;
int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset);
k_srsrc[3] = (nm_filter << 8) + 0x40000;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + warp_global_bytes + k_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
__builtin_amdgcn_sched_barrier(0);
}
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element, int load_id>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_one_gfx938(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int v_row_stride,
int max_seq_v_offset);
template<bool PrefetchVInQK, int K_LOOP_COUNT, int kHeadDim, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_qk_gemm_gfx938(
vec4_uint k_addr,
vec4_uint v_addr,
Element* k_lds,
Element* v_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_k_offset = 0) {
static_assert(!PrefetchVInQK || (kHeadDim == 128 && K_LOOP_COUNT == 2));
int stage_id = 0;
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = k_row_stride;
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
asm volatile(
"v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
: "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[0]),
"=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[1])
:);
}
}
}
stage_id ^= 1;
#pragma unroll
for (int k_loop = 1; k_loop < kHeadDim / 64; ++k_loop) {
int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
int warp_global_bytes;
int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset);
k_srsrc[3] = (nm_filter << 8) + 0x40000;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + warp_global_bytes + k_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
flash::wait_buffer_data_arrived<false/*sync*/>(1);
stage_id ^= 1;
union_vec16_fp8 k_regs[WARP_N / 16];
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
if constexpr (PrefetchVInQK) {
if (min_tile_n == 1) {
fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 0>(
v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
mmac_4interleave_fp8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
{
constexpr int k_loop = kHeadDim / 64;
if constexpr (PrefetchVInQK) {
flash::wait_buffer_data_arrived<false/*sync*/>(1);
} else {
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
stage_id ^= 1;
union_vec16_fp8 k_regs[WARP_N / 16];
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
if constexpr (PrefetchVInQK) {
if (min_tile_n == 1) {
fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 1>(
v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
mmac_4interleave_fp8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
......@@ -65,6 +65,44 @@ inline __device__ void kvcache_apply_mask_causal_gfx938(DataType tensor[M_WARP_C
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_local_causal_gfx938(
DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q,
const int ngroups, const int window_size_left, const int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
const int row_idx = row_idx_base + min_tile_m * 16;
const int logical_row = row_idx / ngroups;
const int logical_q = max_seqlen_q / ngroups;
const int col_idx_limit_left = max(0, logical_row + max_seqlen_k - logical_q - window_size_left);
const int col_idx_limit_right = min(max_seqlen_k, logical_row + max_seqlen_k - logical_q + window_size_right);
#pragma unroll
for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
const int col_idx = col_idx_base + vec_idx;
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] =
(col_idx < col_idx_limit_left || col_idx > col_idx_limit_right)
? -INFINITY
: tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
......@@ -96,3 +134,26 @@ inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WA
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention score helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_kvcache_apply_descale_gfx938(
DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
const __float2 qk_descale) {
#pragma unroll
for (int i = 0; i < M_WARP_COUNT * N_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] =
__builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] =
__builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
}
}
}
}
......@@ -73,13 +73,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -92,13 +92,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -111,13 +111,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -130,13 +130,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -149,13 +149,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -168,13 +168,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
}
// asm volatile("s_nop 8\n");
{
......@@ -187,13 +187,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
}
// 先写一部分数据到 lds
for (int loop_id = 0; loop_id < 7; ++loop_id) {
......@@ -206,13 +206,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(2)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(1)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
__builtin_amdgcn_sched_barrier(0);
acc_o_lds[lds_offset[loop_id]] = acc_tmp_wave0[loop_id].f32[0];
acc_o_lds[lds_offset[loop_id] + 16] = acc_tmp_wave0[loop_id].f32[1];
......@@ -233,26 +233,23 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
union_vec2_fp32 acc_tmp;
int lds_offset0 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 0*16 + (lane_id>>4)*4 + WARP_ID;
int lds_offset1 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 1*16 + (lane_id>>4)*4 + WARP_ID;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0, acc_tmp.u64, 0, 16);
acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
// acc_tmp.f32[0] = acc_o_lds[lds_offset0];
// acc_tmp.f32[1] = acc_o_lds[lds_offset1];
union_vec2_fp32 acc_tmp_wave1;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave1.f32[0] = acc_o_lds[lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave1.f32[1] = acc_o_lds[lds_offset1 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave2.f32[0] = acc_o_lds[lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave2.f32[1] = acc_o_lds[lds_offset1 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
// acc_tmp_wave3.f32[0] = acc_o_lds[lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave3.f32[1] = acc_o_lds[lds_offset1 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment