Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -125,7 +125,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32(
summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64,
additem_pair
);
......@@ -159,7 +159,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32(
summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64,
additem_pair
);
......@@ -260,13 +260,12 @@ inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WA
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) || defined(__gfx938__)
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair,
neg_max_scaled_pair
);
}
asm volatile("s_nop 0" ::: "memory");
for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
}
......@@ -350,7 +349,7 @@ inline __device__ void mla_softmax_rescale_o(
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair
);
......@@ -402,7 +401,7 @@ inline __device__ void mla_softmax_rescale_o(
for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else
cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1];
......@@ -426,7 +425,7 @@ inline __device__ void mla_softmax_rescale_o(
for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32(
scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64,
scores_sum_cur[mi].u64
);
......
File mode changed from 100644 to 100755
......@@ -79,8 +79,6 @@ union union_vec_fp32 {
union union_vec4_uint {
unsigned long long u64[2]; // 128 bits
uint4 u32;
vec4_int i32;
vec4_uint v32;
uint8_t u8[16];
};
......@@ -261,3 +259,4 @@ __forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_
// return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)};
#endif
}
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -158,7 +158,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
params.window_size_right);
}
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, false/*IsInference*/>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
......@@ -227,7 +227,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
}
// TODO: when we have key_padding_mask we'll need to Check_inf
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, false/*IsInference*/>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block * (kBlockN / WARP_N);
......@@ -255,7 +255,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, true/*Is_Interleaved*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, WARP_ID, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, false/*Is_Interleaves*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, WARP_ID, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
......@@ -458,7 +458,7 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
// convertType: float2half
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, true/*IsInference*/>(p_reg, s_reg);
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
Is_even_MN
? PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, WARP_ID, seqlen_k_stride, seqlen_v_stride, 0)
......@@ -1116,7 +1116,7 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params &params, c
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_gfx938<kHeadDimPVCompute, kBlockM, kBlockK, WARP_M, TailTile16, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
fwd_epilogue_store_output_gfx938<kHeadDimPVCompute, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
......@@ -1397,7 +1397,7 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_gfx938<kHeadDimV, kBlockM, kBlockK, WARP_M, 2/*TailTile16*/, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
fwd_epilogue_store_output_gfx938<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
}
......
......@@ -1132,22 +1132,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......@@ -1391,22 +1385,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......@@ -1662,22 +1650,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
......@@ -1767,4 +1749,4 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
#endif
}
} // namespace flash
} // namespace flash
\ No newline at end of file
......@@ -646,6 +646,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
inline_utcl2_warmup_dword(k_addr);
inline_utcl2_warmup_dword(v_addr);
}
// Keep warmup buffer loads out of the MLS vmcnt schedule below.
flash::wait_all_buffer_data_arrived<true>();
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
int row_offset_lse;
......@@ -1268,4 +1270,4 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_kernel(Params pa
}
}
}
\ No newline at end of file
}
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment