Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -125,7 +125,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT ...@@ -125,7 +125,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary[m_idx * 2].u64 = hcu_pk_add_f32( summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx * 2].u64, summary[m_idx * 2].u64,
additem_pair additem_pair
); );
...@@ -159,7 +159,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT ...@@ -159,7 +159,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
__float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * M_WARP_COUNT][min_tile_n * 2 + 1].f32[vec_idx]};
summary_cur[m_idx * 2].u64 = hcu_pk_add_f32( summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx * 2].u64, summary_cur[m_idx * 2].u64,
additem_pair additem_pair
); );
...@@ -260,13 +260,12 @@ inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WA ...@@ -260,13 +260,12 @@ inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WA
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
for (int vec_idx = 0; vec_idx < 2; vec_idx++) { for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32( tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx], tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scale_pair, scale_pair,
neg_max_scaled_pair neg_max_scaled_pair
); );
} }
asm volatile("s_nop 0" ::: "memory");
for (int vec_idx = 0; vec_idx < 4; vec_idx++) { for (int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]); tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx]);
} }
...@@ -350,7 +349,7 @@ inline __device__ void mla_softmax_rescale_o( ...@@ -350,7 +349,7 @@ inline __device__ void mla_softmax_rescale_o(
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll #pragma unroll
for (int vec_idx = 0; vec_idx < 2; vec_idx++) { for (int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32( acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx], acc_o[loop_id][min_tile_n * 2 + min_tile_m].u64[vec_idx],
scores_scale_pair scores_scale_pair
); );
...@@ -402,7 +401,7 @@ inline __device__ void mla_softmax_rescale_o( ...@@ -402,7 +401,7 @@ inline __device__ void mla_softmax_rescale_o(
for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) { for (int warp_loop = 1; warp_loop < WARP_NUM; ++warp_loop) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2); __float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop * WARP_M + mi * 32 + lane_id * 2);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum); cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else #else
cur_wave_sum[0] += other_warp_sum[0]; cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1]; cur_wave_sum[1] += other_warp_sum[1];
...@@ -426,7 +425,7 @@ inline __device__ void mla_softmax_rescale_o( ...@@ -426,7 +425,7 @@ inline __device__ void mla_softmax_rescale_o(
for (int mi = 0; mi < M_WARP_COUNT; ++mi) { for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
scores_sum[mi].u64 = hcu_pk_add_f32( scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64, scores_sum[mi].u64,
scores_sum_cur[mi].u64 scores_sum_cur[mi].u64
); );
......
File mode changed from 100644 to 100755
...@@ -79,8 +79,6 @@ union union_vec_fp32 { ...@@ -79,8 +79,6 @@ union union_vec_fp32 {
union union_vec4_uint { union union_vec4_uint {
unsigned long long u64[2]; // 128 bits unsigned long long u64[2]; // 128 bits
uint4 u32; uint4 u32;
vec4_int i32;
vec4_uint v32;
uint8_t u8[16]; uint8_t u8[16];
}; };
...@@ -261,3 +259,4 @@ __forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_ ...@@ -261,3 +259,4 @@ __forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_
// return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)}; // return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)};
#endif #endif
} }
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -158,7 +158,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in ...@@ -158,7 +158,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
params.window_size_right); params.window_size_right);
} }
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, false/*IsInference*/>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2); softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if constexpr (Is_dropout and Is_training) { if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N); warp_idx_for_dropout.u32.y = n_block_loop * (kBlockN / WARP_N);
...@@ -227,7 +227,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in ...@@ -227,7 +227,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
} }
// TODO: when we have key_padding_mask we'll need to Check_inf // TODO: when we have key_padding_mask we'll need to Check_inf
softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, false/*IsInference*/>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2); softmax_rescale_o<false, Is_causal || Is_local, vec4_Accum<ElementAccum>, vec2_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if constexpr (Is_dropout and Is_training) { if constexpr (Is_dropout and Is_training) {
warp_idx_for_dropout.u32.y = n_block * (kBlockN / WARP_N); warp_idx_for_dropout.u32.y = n_block * (kBlockN / WARP_N);
...@@ -255,7 +255,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in ...@@ -255,7 +255,7 @@ inline __device__ void compute_attn_mha_1rowblock(const Params &params, const in
} }
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o; Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, true/*Is_Interleaved*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, WARP_ID, lane_id, seqlen_o_stride, binfo.actual_seqlen_q); fwd_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, false/*Is_Interleaves*/, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, WARP_ID, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
} }
...@@ -458,7 +458,7 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p ...@@ -458,7 +458,7 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4]; union_vec2_f16x2<Element> p_reg[(WARP_M / 32) * (kBlockN / 32)][4];
// convertType: float2half // convertType: float2half
convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, true/*IsInference*/>(p_reg, s_reg); convert_pk_type<WARP_M, kBlockN, Element, ElementAccum>(p_reg, s_reg);
Is_even_MN Is_even_MN
? PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, WARP_ID, seqlen_k_stride, seqlen_v_stride, 0) ? PV_GEMM_FUNC(gV, gK, v_lds, k_lds, p_reg, acc_o, WARP_ID, seqlen_k_stride, seqlen_v_stride, 0)
...@@ -1116,7 +1116,7 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params &params, c ...@@ -1116,7 +1116,7 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params &params, c
} }
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o; Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_gfx938<kHeadDimPVCompute, kBlockM, kBlockK, WARP_M, TailTile16, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q); fwd_epilogue_store_output_gfx938<kHeadDimPVCompute, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
} }
...@@ -1397,7 +1397,7 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params ...@@ -1397,7 +1397,7 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
} }
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o; Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
fwd_epilogue_store_output_gfx938<kHeadDimV, kBlockM, kBlockK, WARP_M, 2/*TailTile16*/, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q); fwd_epilogue_store_output_gfx938<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, binfo.actual_seqlen_q);
} }
......
...@@ -1132,22 +1132,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co ...@@ -1132,22 +1132,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0; scores_sum[i].f32[0] = 0;
} }
uint64_t pk_zero = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) { for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll #pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll #pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) #if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#else #else
acc_o[i][min_tile_n].f32[0] = 0; acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0; acc_o[i][min_tile_n].f32[1] = 0;
...@@ -1391,22 +1385,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co ...@@ -1391,22 +1385,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum[i].f32[0] = 0; scores_sum[i].f32[0] = 0;
} }
uint64_t pk_zero = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) { for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll #pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll #pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) #if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#else #else
acc_o[i][min_tile_n].f32[0] = 0; acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0; acc_o[i][min_tile_n].f32[1] = 0;
...@@ -1662,22 +1650,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con ...@@ -1662,22 +1650,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
scores_sum[i].f32[0] = 0; scores_sum[i].f32[0] = 0;
} }
uint64_t pk_zero = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) { for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll #pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll #pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#if defined(__gfx936__) #if defined(__gfx936__) || defined(__gfx938__)
acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(0); acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#elif defined(__gfx938__)
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[0])
:);
asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n].u64[1])
:);
#else #else
acc_o[i][min_tile_n].f32[0] = 0; acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0; acc_o[i][min_tile_n].f32[1] = 0;
...@@ -1767,4 +1749,4 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con ...@@ -1767,4 +1749,4 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
#endif #endif
} }
} // namespace flash } // namespace flash
\ No newline at end of file
...@@ -646,6 +646,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param ...@@ -646,6 +646,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params &param
inline_utcl2_warmup_dword(k_addr); inline_utcl2_warmup_dword(k_addr);
inline_utcl2_warmup_dword(v_addr); inline_utcl2_warmup_dword(v_addr);
} }
// Keep warmup buffer loads out of the MLS vmcnt schedule below.
flash::wait_all_buffer_data_arrived<true>();
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum // splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
int row_offset_lse; int row_offset_lse;
...@@ -1268,4 +1270,4 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_kernel(Params pa ...@@ -1268,4 +1270,4 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_kernel(Params pa
} }
} }
} }
\ No newline at end of file
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment