Commit 5248d7d2 authored by hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
......@@ -236,7 +236,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) {
__float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]};
summary[m_idx*2].u64 = hcu_pk_add_f32(
summary[m_idx*2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx*2].u64,
additem_pair
);
......@@ -270,7 +270,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { // mmac min_tile is 16*16, a warp is 64 thread
__float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]};
summary_cur[m_idx*2].u64 = hcu_pk_add_f32(
summary_cur[m_idx*2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx*2].u64,
additem_pair
);
......@@ -372,13 +372,12 @@ inline __device__ void int8_kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M/32
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if defined(__gfx936__) || defined(__gfx938__)
for(int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32(
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
}
......@@ -482,7 +481,7 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
......@@ -535,7 +534,7 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop*WARP_M + mi*32 + lane_id*2);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
......@@ -560,7 +559,7 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......
......@@ -38,20 +38,17 @@ __forceinline__ __device__ void kvcache_acco_reduce(
union_vec2_fp32 acc_tmp;
int lds_offset0 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 0 * 16 + (lane_id >> 4) * 4 + WARP_ID;
int lds_offset1 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 1 * 16 + (lane_id >> 4) * 4 + WARP_ID;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0, acc_tmp.u64, 0, 16);
acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
union_vec2_fp32 acc_tmp_wave1;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 1 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave1.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 2 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave2.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 3 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave3.u64, 0, 16);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
acc_tmp.f32[1] += acc_tmp_wave3.f32[1];
// ds_write2_b32
......
......@@ -10,7 +10,7 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
int warp_id,
int lane_id) {
#if defined(__gfx938__)
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and Padding == 0); // Specialized optimizatio for headdim 128
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and Padding == 0 and K_LOOP_COUNT == WARP_NUM and K_WARP_COUNT == 1 and M_WARP_COUNT == 1); // Specialized optimization for headdim 128
#else
constexpr int OPT_FOR_HDIM128 = false; // keep same as origin for archs <= gfx936
#endif
......@@ -90,7 +90,8 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// A wave collectively holds seqlen_q x kHeadDim Half values, but to save LDS usage, only seqlen_q x kBlockK Half values are reduced at a time
int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT + q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4/*0~3*/) * 4/*0~15*/;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32;
int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT + k_idx * M_WARP_COUNT;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32;
}
}
}
......@@ -135,7 +136,8 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4;
acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT + k_idx * M_WARP_COUNT;
acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
}
}
}
......
......@@ -23,7 +23,7 @@ __forceinline__ __device__ void kvcache_epilugue_rescale_acco(
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
......
File mode changed from 100644 to 100755
......@@ -74,7 +74,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4;
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
......@@ -97,7 +97,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
......@@ -168,7 +168,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4;
precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
}
}
}
......@@ -191,7 +191,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET);
v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
}
}
}
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -227,7 +227,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -262,7 +262,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -365,13 +365,12 @@ inline __device__ void kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M / 32) *
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32(
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
#pragma unroll
for(int vec_idx = 0; vec_idx < 4; ++vec_idx) {
tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
......@@ -451,7 +450,7 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; ++vec_idx) {
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + (mi + ni * (WARP_M / 32))][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
......@@ -504,7 +503,7 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
for(int warp_loop = 1; warp_loop < WARP_NUM; warp_loop++) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
......@@ -529,7 +528,7 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......
File mode changed from 100644 to 100755
......@@ -104,8 +104,8 @@ __forceinline__ __device__ void fp8_mla_acco_reduce_tile16x32(
data.f32[1] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 1 * 64];
data.f32[2] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 2 * 64];
data.f32[3] = acc_o_lds[neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id + 3 * 64];
acc_o[k_loop + 0][min_tile_n * 2].u64[0] = hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], data.u64[0]);
acc_o[k_loop + 0][min_tile_n * 2].u64[1] = hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], data.u64[1]);
acc_o[k_loop + 0][min_tile_n * 2].u64[0] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], data.u64[0]);
acc_o[k_loop + 0][min_tile_n * 2].u64[1] = __builtin_hcu_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], data.u64[1]);
}
}
__syncthreads();
......
......@@ -24,7 +24,7 @@ __forceinline__ __device__ void fp8_mla_epilugue_rescale_acco_gfx938(
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__)
for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair
);
......
......@@ -75,8 +75,8 @@ inline __device__ void fp8_mla_apply_descale_gfx938(DataType tensor[M_WARP_COUNT
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] = hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] = hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[0], qk_descale);
tensor[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_pk_mul_f32(tensor[i][min_tile_n * 2 + min_tile_m].u64[1], qk_descale);
}
}
}
......
......@@ -88,16 +88,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[1] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[2] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[3] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[0][0], v_f32x2[0][1]);
v_f16x8.f16x2[1] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[1][0], v_f32x2[1][1]);
v_f16x8.f16x2[2] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[2][0], v_f32x2[2][1]);
v_f16x8.f16x2[3] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[3][0], v_f32x2[3][1]);
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
......@@ -151,16 +151,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// fp8 -> f32
vec2_fp32 v_f32x2[4]; // 8 fp8 -> 8 f32, for 1 mmac
v_f32x2[0] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[1] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0]);
v_f32x2[2] = hcu_cvt_pk_f32_fp8<0>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[3] = hcu_cvt_pk_f32_fp8<2>(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1]);
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// f32 -> fp16
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[0][0], v_f32x2[0][1]);
v_f16x8.f16x2[1] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[1][0], v_f32x2[1][1]);
v_f16x8.f16x2[2] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[2][0], v_f32x2[2][1]);
v_f16x8.f16x2[3] = hcu_cvt_pk_f16_f32<false, 0>(v_f32x2[3][0], v_f32x2[3][1]);
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
// mmac_16x16x16, 4 fp16
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment