Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
...@@ -36,7 +36,7 @@ __forceinline__ __device__ void prefill_mla_epilugue_rescale_acco( ...@@ -36,7 +36,7 @@ __forceinline__ __device__ void prefill_mla_epilugue_rescale_acco(
const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi; const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) + ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
for(int vec_id = 0; vec_id < 2; ++vec_id) { for(int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[pv_tile_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32( acc_o[pv_tile_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_id], acc_o[pv_tile_id][mmac_id].u64[vec_id],
scale_pair scale_pair
); );
......
...@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512( ...@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
// DS // DS
...@@ -136,10 +133,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512( ...@@ -136,10 +133,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
} }
stage_id ^= 1; stage_id ^= 1;
...@@ -200,10 +194,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512( ...@@ -200,10 +194,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
} }
int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
} }
lds_stage_id ^= 1; lds_stage_id ^= 1;
...@@ -309,4 +300,4 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512( ...@@ -309,4 +300,4 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset); prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK); prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
} }
} }
\ No newline at end of file
...@@ -29,9 +29,6 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512( ...@@ -29,9 +29,6 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512(
int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES; int lds_offset = (lds_stage_id * WARP_K * kHeadDim_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束 flash::wait_all_warp_arrived(); // 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + lds_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
...@@ -106,10 +106,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512( ...@@ -106,10 +106,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8; q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES; int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
if (k_even) { if (k_even) {
k_stage_id ^= 1; k_stage_id ^= 1;
...@@ -122,10 +119,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512( ...@@ -122,10 +119,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8; k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
} }
...@@ -324,4 +318,4 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512( ...@@ -324,4 +318,4 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
#endif #endif
} }
} // qk_gemm } // qk_gemm
\ No newline at end of file
...@@ -36,10 +36,7 @@ __forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512( ...@@ -36,10 +36,7 @@ __forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512(
int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES; int lds_offset = (stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); // pvgemm 完成后会发射q,k的预取,避免有的warp还没完成,即规避读V写Q/K,造成数据覆盖 flash::wait_all_warp_arrived(); // pvgemm 完成后会发射q,k的预取,避免有的warp还没完成,即规避读V写Q/K,造成数据覆盖
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
} }
} }
...@@ -71,9 +68,6 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512( ...@@ -71,9 +68,6 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512(
} }
int lds_offset = (stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES; int lds_offset = (stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset;
matrix_load_b16_lds_trans_builtin<32, 16, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
\ No newline at end of file
...@@ -16,10 +16,10 @@ struct PrefillMlaAllreduce { ...@@ -16,10 +16,10 @@ struct PrefillMlaAllreduce {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
res.f32[0] = __shfl_xor_tmp(x.f32[0], 32); res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 32); res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
x.u64 = hcu_pk_add_f32(x.u64, res.u64); x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
res.f32[0] = __shfl_xor_tmp(x.f32[0], 16); res.f32[0] = __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 16); res.f32[1] = __shfl_xor_tmp(x.f32[1], 16);
res.u64 = hcu_pk_add_f32(res.u64, x.u64); res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else #else
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32); x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32); x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
...@@ -113,7 +113,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR ...@@ -113,7 +113,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
if constexpr (M_MMAC_COUNT == 2){ if constexpr (M_MMAC_COUNT == 2){
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32( summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64, summary[m_idx * 2].u64,
additem_pair additem_pair
); );
...@@ -159,7 +159,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR ...@@ -159,7 +159,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread
if constexpr (M_MMAC_COUNT == 2) { if constexpr (M_MMAC_COUNT == 2) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32( summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64, summary_cur[m_idx * 2].u64,
additem_pair additem_pair
); );
...@@ -275,13 +275,12 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / (16 * M_MMAC_ ...@@ -275,13 +275,12 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / (16 * M_MMAC_
int qk_tile_id = mi + ni * (WARP_M / (16 * M_MMAC_COUNT)); int qk_tile_id = mi + ni * (WARP_M / (16 * M_MMAC_COUNT));
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) { for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[qk_tile_id][mmac_id].u64[vec_idx] = hcu_pk_fma_f32( tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[qk_tile_id][mmac_id].u64[vec_idx], tensor[qk_tile_id][mmac_id].u64[vec_idx],
scale_pair, scale_pair,
neg_max_scaled_pair neg_max_scaled_pair
); );
} }
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]); tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
} }
...@@ -343,7 +342,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N / ...@@ -343,7 +342,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll #pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) { for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_tile_id][mmac_id].u64[vec_idx] = hcu_pk_mul_f32( acc_o[pv_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_tile_id][mmac_id].u64[vec_idx], acc_o[pv_tile_id][mmac_id].u64[vec_idx],
scores_scale_pair scores_scale_pair
); );
...@@ -374,7 +373,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N / ...@@ -374,7 +373,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) { for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
if constexpr (M_MMAC_COUNT == 2) { if constexpr (M_MMAC_COUNT == 2) {
scores_sum[mi].u64 = hcu_pk_add_f32( scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64, scores_sum[mi].u64,
scores_sum_cur[mi].u64 scores_sum_cur[mi].u64
); );
......
File mode changed from 100644 to 100755
...@@ -33,10 +33,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio ...@@ -33,10 +33,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element)); *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
// matrix load // matrix load
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
...@@ -63,10 +60,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio ...@@ -63,10 +60,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
int q_warp_offset = (LOAD * WARP_NUM + real_warp_id) * 32; int q_warp_offset = (LOAD * WARP_NUM + real_warp_id) * 32;
*(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element)); *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 16, 1, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// continue from MID // continue from MID
...@@ -89,4 +83,4 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio ...@@ -89,4 +83,4 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
// wait all data written to registers // wait all data written to registers
flash::wait_lds_data_arrived<true/*sync*/>(0); flash::wait_lds_data_arrived<true/*sync*/>(0);
} }
\ No newline at end of file
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -22,7 +22,7 @@ __forceinline__ __device__ void mla_epilugue_rescale_acco( ...@@ -22,7 +22,7 @@ __forceinline__ __device__ void mla_epilugue_rescale_acco(
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi); int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
for (int vec_id = 0; vec_id < 2; ++vec_id) { for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32( acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id], acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair scale_pair
); );
......
File mode changed from 100644 to 100755
...@@ -428,8 +428,8 @@ __forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(vec4_ ...@@ -428,8 +428,8 @@ __forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(vec4_
: ((warp_id & 1) ? warp_id - 1: warp_id + 1); : ((warp_id & 1) ? warp_id - 1: warp_id + 1);
int lds_load_offset = n_loop * WARP_NUM * (64 * 4) + warp_id_symmetry * 64 * 4 + lane_id * 4; int lds_load_offset = n_loop * WARP_NUM * (64 * 4) + warp_id_symmetry * 64 * 4 + lane_id * 4;
vec4_Accum<ElementAccum> symmetry_data = *(vec4_Accum<ElementAccum>*)(s_reg_lds + lds_load_offset); vec4_Accum<ElementAccum> symmetry_data = *(vec4_Accum<ElementAccum>*)(s_reg_lds + lds_load_offset);
s_reg[m_idx][n_loop].u64[0] = hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[0], symmetry_data.u64[0]); s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[0], symmetry_data.u64[0]);
s_reg[m_idx][n_loop].u64[1] = hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[1], symmetry_data.u64[1]); s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_add_f32(s_reg[m_idx][n_loop].u64[1], symmetry_data.u64[1]);
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
__syncthreads(); __syncthreads();
...@@ -471,9 +471,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax( ...@@ -471,9 +471,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scale_softmax_log2_pair[1] = scale_softmax_log2; scale_softmax_log2_pair[1] = scale_softmax_log2;
#pragma unroll #pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) { for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
s_reg[m_idx][n_loop].u64[0] = hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled); s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled);
s_reg[m_idx][n_loop].u64[1] = hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled); s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled);
asm volatile("s_nop 0" ::: "memory");
s_reg[m_idx][n_loop].f32[0] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[0]); s_reg[m_idx][n_loop].f32[0] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[0]);
s_reg[m_idx][n_loop].f32[1] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[1]); s_reg[m_idx][n_loop].f32[1] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[1]);
s_reg[m_idx][n_loop].f32[2] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[2]); s_reg[m_idx][n_loop].f32[2] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[2]);
...@@ -489,8 +488,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax( ...@@ -489,8 +488,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum_pair[1] = 0; scores_sum_pair[1] = 0;
#pragma unroll #pragma unroll
for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) { for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop) {
scores_sum_pair = hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]); scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]);
scores_sum_pair = hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]); scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]);
} }
scores_sum_cur[m_idx] = scores_sum_pair[0] + scores_sum_pair[1]; scores_sum_cur[m_idx] = scores_sum_pair[0] + scores_sum_pair[1];
scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 32); scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 32);
...@@ -505,8 +504,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax( ...@@ -505,8 +504,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum[m_idx] *= scores_scale[0]; scores_sum[m_idx] *= scores_scale[0];
#pragma unroll #pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit; ++pv_tile) { for (int pv_tile = 0; pv_tile < kHeadDimVSplit; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale); acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale);
acc_o[m_idx][pv_tile].u64[1] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale); acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale);
} }
} }
// update max/sum // update max/sum
...@@ -932,8 +931,8 @@ __forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o( ...@@ -932,8 +931,8 @@ __forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o(
inv_sum[1] = inv_sum[0]; inv_sum[1] = inv_sum[0];
#pragma unroll #pragma unroll
for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile) { for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile) {
acc_o[m_idx][pv_tile].u64[0] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum); acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum);
acc_o[m_idx][pv_tile].u64[1] = hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum); acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum);
} }
} }
} }
......
File mode changed from 100644 to 100755
...@@ -74,7 +74,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32( ...@@ -74,7 +74,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4; precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
} }
} }
} }
...@@ -97,7 +97,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32( ...@@ -97,7 +97,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll #pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET); v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
} }
} }
} }
...@@ -220,7 +220,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32( ...@@ -220,7 +220,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4; precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
} }
} }
} }
...@@ -243,7 +243,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32( ...@@ -243,7 +243,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll #pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET); v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
} }
} }
} }
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment