Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
...@@ -109,13 +109,13 @@ struct Allreduce { ...@@ -109,13 +109,13 @@ struct Allreduce {
static __device__ inline union_vec2_fp32 run(union_vec2_fp32 x, Operator &op) { static __device__ inline union_vec2_fp32 run(union_vec2_fp32 x, Operator &op) {
union_vec2_fp32 res; union_vec2_fp32 res;
if constexpr (std::is_same<Operator, SumOp<float> >::value) { if constexpr (std::is_same<Operator, SumOp<float> >::value) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
res.f32[0] = __shfl_xor_tmp(x.f32[0], 32); res.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
res.f32[1] = __shfl_xor_tmp(x.f32[1], 32); res.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
x.u64 = hcu_pk_add_f32(x.u64, res.u64); x.u64 = __builtin_hcu_pk_add_f32(x.u64, res.u64);
res.f32[0] = __shfl_swap16(x.f32[0]); // __shfl_xor_tmp(x.f32[0], 16); res.f32[0] = __shfl_swap16(x.f32[0]); // __shfl_xor_tmp(x.f32[0], 16);
res.f32[1] = __shfl_swap16(x.f32[1]); // __shfl_xor_tmp(x.f32[1], 16); res.f32[1] = __shfl_swap16(x.f32[1]); // __shfl_xor_tmp(x.f32[1], 16);
res.u64 = hcu_pk_add_f32(res.u64, x.u64); res.u64 = __builtin_hcu_pk_add_f32(res.u64, x.u64);
#else #else
x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32); x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32); x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
...@@ -141,10 +141,7 @@ struct Allreduce { ...@@ -141,10 +141,7 @@ struct Allreduce {
template<const int kHeadDim, typename T, bool Do_CacheSwizzle=true> template<const int kHeadDim, typename T, bool Do_CacheSwizzle=true>
__device__ __forceinline__ vec4_uint prepare_for_buffer_load(T* ptr) { __device__ __forceinline__ vec4_uint prepare_for_buffer_load(T* ptr) {
vec4_uint res; vec4_uint res;
struct { uint32_t lo, hi; } parts; *(uint64_t*)&res = reinterpret_cast<uint64_t>(ptr);
*(uint64_t*)&parts = reinterpret_cast<uint64_t>(ptr);
res[0] = __builtin_amdgcn_readfirstlane(parts.lo);
res[1] = __builtin_amdgcn_readfirstlane(parts.hi);
if constexpr (Do_CacheSwizzle) { if constexpr (Do_CacheSwizzle) {
if constexpr (kHeadDim == 128) { if constexpr (kHeadDim == 128) {
res[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride res[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
...@@ -194,7 +191,7 @@ __forceinline__ __device__ void attention_initialize( ...@@ -194,7 +191,7 @@ __forceinline__ __device__ void attention_initialize(
#if defined(__gfx936__) #if defined(__gfx936__)
acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(pk_zero); acc_o[i][min_tile_n * 2 + min_tile_m].u64[0] = __builtin_hcu_mov_b64(pk_zero);
acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(pk_zero); acc_o[i][min_tile_n * 2 + min_tile_m].u64[1] = __builtin_hcu_mov_b64(pk_zero);
#elif defined(__gfx938__) #elif defined(__gfx938__) || defined(__gfx946__)
asm volatile("v_mov_b64 %0, 0x0" asm volatile("v_mov_b64 %0, 0x0"
: "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[0]) : "=v"(acc_o[i][min_tile_n * 2 + min_tile_m].u64[0])
:); :);
...@@ -213,4 +210,27 @@ __forceinline__ __device__ void attention_initialize( ...@@ -213,4 +210,27 @@ __forceinline__ __device__ void attention_initialize(
} }
// Resets the state consumed by the fp8 attention loop:
//   scores_max : WARP_M / 16 entries, each set to -INFINITY.
//   scores_sum : WARP_M / 16 entries, each set to 0.
//   acc_o      : [kHeadDim / 32][WARP_M / 16][WARP_N / 16] accumulators, each
//                cleared with inline_vgpr4_init_zero.
// All trip counts are compile-time constants and every loop is fully unrolled.
template<int kHeadDim, int WARP_M, int WARP_N, typename ElementAccum>
__forceinline__ __device__ void fp8_attention_initialize(
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    vec4_Accum<ElementAccum> acc_o[kHeadDim / 32][WARP_M / 16][WARP_N / 16]
) {
    constexpr int kTilesM = WARP_M / 16;
    constexpr int kTilesN = WARP_N / 16;
    constexpr int kPvSteps = kHeadDim / 32;

    // Per-row-tile reduction state.
    #pragma unroll
    for (int row = 0; row < kTilesM; ++row) {
        scores_max[row] = -INFINITY;
        scores_sum[row] = 0;
    }

    // Output accumulators, one vec4 per (pv step, row tile, column tile).
    #pragma unroll
    for (int step = 0; step < kPvSteps; ++step) {
        #pragma unroll
        for (int row = 0; row < kTilesM; ++row) {
            #pragma unroll
            for (int col = 0; col < kTilesN; ++col) {
                inline_vgpr4_init_zero(acc_o[step][row][col]);
            }
        }
    }
}
} // namespace flash } // namespace flash
...@@ -4,28 +4,43 @@ ...@@ -4,28 +4,43 @@
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
#include "numeric_types.h" #include "numeric_types.h"
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#define USE_BUFFER_LOAD_DWORDX4 #define USE_BUFFER_LOAD_DWORDX4
// #define USE_BUFFER_LOAD_DWORDX2 // #define USE_BUFFER_LOAD_DWORDX2
#endif #endif
// DTK ds_read_matrix builtins (DS_READ_MATRIX_FORMAT / _TRANS_FORMAT): para1 is LDS base
// typed per element kind — e.g. *_f16 → half*3, *_bf16 / *_u16 / *_i16 → short*3, *_f32 → float*3,
// 4/8-bit and tf32/u32/i32 variants → int*3 (vendor builtin table). template<typename VEC, typename pointerType>
// HIP may use __half for fp16 LDS while builtins expect __fp16*3; use f16 helper below. __forceinline__ __device__ void inline_global_load_dwordx1(VEC& v_data, const int v_offset, const pointerType* s_addr) {
// Probe: FA_PROBE_FAMILY_DS (lds_f16_as3, lds_bf16_as3). const int v_offset_bytes = v_offset * sizeof(pointerType);
template<typename T> asm volatile(
__forceinline__ __device__ __attribute__((address_space(3))) __fp16 * "global_load_dword %0, %1, %2\n"
hcu_ds_read_matrix_f16_lds_base(T *const p) { : "=v"(v_data)
return (__attribute__((address_space(3))) __fp16 *)(p); : "v"(v_offset_bytes), "s"(s_addr)
:);
} }
template<typename T> template<const int shfl_count, bool bypass, class DataType>
__forceinline__ __device__ __attribute__((address_space(3))) short * __forceinline__ __device__ void inline_buffer_load_dwordx1(DataType& v_data, const int v_offset, const vec4_uint global_addr) {
hcu_ds_read_matrix_bf16_lds_base(T *const p) {
return (__attribute__((address_space(3))) short *)(p); int v_offset_bytes = v_offset << shfl_count;
if constexpr (bypass) {
asm volatile(
"buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc\n"
: "=v"(v_data)
: "v"(v_offset_bytes), "s"(global_addr)
:);
} else {
asm volatile(
"buffer_load_dword %0, %1, %2, 0, offen offset:0\n"
: "=v"(v_data)
: "v"(v_offset_bytes), "s"(global_addr)
:);
}
} }
template<class DataType> template<class DataType>
__forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resource) { __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resource) {
int container; int container;
...@@ -34,6 +49,7 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour ...@@ -34,6 +49,7 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour
asm volatile( asm volatile(
"s_nop 4\n\t" "s_nop 4\n\t"
"buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc\n\t" "buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc\n\t"
"s_waitcnt vmcnt(0)\n"
: "=v"(container) : "=v"(container)
: "v"(offset), "s"(buffer_resource) : "v"(offset), "s"(buffer_resource)
); );
...@@ -44,177 +60,99 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour ...@@ -44,177 +60,99 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dword_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n" "buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx2_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dwordx2_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void safe_inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &offset_s, const int &offset_v) { __forceinline__ __device__ void safe_inline_buffer_load_dwordx4_lds(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &offset_s, const int &offset_v) {
int lds_addr_per_wave = __builtin_amdgcn_readfirstlane( int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int __offset_s = offset_s << shfl_count; int __offset_s = offset_s << shfl_count;
int __offset_v = offset_v << shfl_count; int __offset_v = offset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_nop 3\n\t" asm volatile("s_nop 3\n\t"
"s_mov_b32 m0, %1\n\t" "s_mov_b32 m0, %1\n\t"
"s_nop 0\n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
:: "v"(__offset_v), "s"(lds_addr_per_wave), "s"(scalar_rsrc), "s"(__offset_s) :: "v"(__offset_v), "s"(lds_addr_per_wave), "s"(global_addr), "s"(__offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_glc_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_glc_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc slc lds\n" "buffer_load_dword %0, %2, %3 ,offen offset:0 glc slc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l1_glc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l1_glc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc lds\n" "buffer_load_dword %0, %2, %3 ,offen offset:0 glc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<class DataType, const int shfl_count=2> template<class DataType, const int shfl_count=2>
__forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l2_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void inline_buffer_load_dword_lds_bypass_l2_slc(DataType *const shared_addr, const vec4_uint global_addr, const int &lds_offset, const int &gvOffset_s, const int &gvOffset_v) {
int ldsAddrPerWave = __builtin_amdgcn_readfirstlane( int ldsAddrPerWave = reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count);
(int)(reinterpret_cast<size_t>(shared_addr) + (lds_offset << shfl_count)));
int offset_s = gvOffset_s << shfl_count; int offset_s = gvOffset_s << shfl_count;
int offset_v = gvOffset_v << shfl_count; int offset_v = gvOffset_v << shfl_count;
vec4_uint scalar_rsrc;
scalar_rsrc[0] = __builtin_amdgcn_readfirstlane(global_addr[0]);
scalar_rsrc[1] = __builtin_amdgcn_readfirstlane(global_addr[1]);
scalar_rsrc[2] = __builtin_amdgcn_readfirstlane(global_addr[2]);
scalar_rsrc[3] = __builtin_amdgcn_readfirstlane(global_addr[3]);
asm volatile("s_mov_b32 m0, %1 \n\t" asm volatile("s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0 slc lds\n" "buffer_load_dword %0, %2, %3 ,offen offset:0 slc lds\n"
:: "v"(offset_v), "s"(ldsAddrPerWave), "s"(scalar_rsrc), "s"(offset_s) :: "v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
} }
template<typename src_type=half_t, typename dst_type=float, const int dword_count=1, const int auxilariy=0> template<typename src_type=half_t, typename dst_type=float, const int dword_count=1, const int auxilariy=0>
__forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
#if defined(__gfx936__) || defined(__gfx938__)
static_assert(dword_count == 1 || dword_count == 2 || dword_count == 4, "unsupported buffer_load_dword LDS width");
// DTK currently accepts the mature asm buffer_load_* -> lds shape more reliably than
// the raw_buffer_load_lds wrapper instantiated through generic LDS pointers.
if constexpr (auxilariy == 0) {
if constexpr (dword_count == 1) {
inline_buffer_load_dword_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else if constexpr (dword_count == 2) {
inline_buffer_load_dwordx2_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else {
inline_buffer_load_dwordx4_lds<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
}
} else if constexpr (auxilariy == 11 && dword_count == 1) {
inline_buffer_load_dword_lds_bypass_glc_slc<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
} else {
constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element);
__builtin_hcu_raw_buffer_load_lds(
rsrc,
ptr,
dword_count * 4,
gvOffset_v * bytes_per_element,
gvOffset_s * bytes_per_element,
0, /* immediate offset, instruction offset */
auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
constexpr int bytes_per_element = sizeof(dst_type); constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element); dst_type *ptr = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
__builtin_hcu_raw_buffer_load_lds( __builtin_hcu_raw_buffer_load_lds(
rsrc, rsrc,
ptr, ptr,
...@@ -224,16 +162,12 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const sh ...@@ -224,16 +162,12 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const sh
0, /* immediate offset, instruction offset */ 0, /* immediate offset, instruction offset */
auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */ auxilariy /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
); );
#endif
} }
template<typename src_type=half_t, typename dst_type=float> template<typename src_type=half_t, typename dst_type=float>
__forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) { __forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src_type *const shared_addr, const vec4_uint rsrc, const int &lds_offset, const int gvOffset_s, const int &gvOffset_v) {
#if defined(__gfx936__) || defined(__gfx938__)
inline_buffer_load_dword_lds_bypass_glc_slc<src_type, 2>(shared_addr, rsrc, lds_offset, gvOffset_s, gvOffset_v);
#else
constexpr int bytes_per_element = sizeof(dst_type); constexpr int bytes_per_element = sizeof(dst_type);
auto *ptr = (__attribute__((address_space(3))) int *)(reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset) * bytes_per_element); dst_type *ptr = reinterpret_cast<dst_type*>(shared_addr) + lds_offset;
__builtin_hcu_raw_buffer_load_lds( __builtin_hcu_raw_buffer_load_lds(
rsrc, rsrc,
ptr, ptr,
...@@ -243,7 +177,6 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src ...@@ -243,7 +177,6 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src
0, /* immediate offset, instruction offset */ 0, /* immediate offset, instruction offset */
11 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */ 11 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
); );
#endif
} }
template<class DataType, const int shfl_count> template<class DataType, const int shfl_count>
...@@ -335,6 +268,16 @@ __forceinline__ __device__ void inline_ds_read2_b32_no_wait_bytes(const int &l ...@@ -335,6 +268,16 @@ __forceinline__ __device__ void inline_ds_read2_b32_no_wait_bytes(const int &l
} }
// Issues a single ds_read2_b64 from LDS into reg_val.
//   lds_offset : LDS address operand. It is passed in a VGPR ("v"), matching
//                the ds_read_b128 / ds_write_b128 wrappers in this file; DS
//                instructions take their address from a vector register.
//   reg_val    : destination written by the instruction. ds_read2_b64 returns
//                two 64-bit values, so DataType is presumably a 128-bit
//                register type -- TODO confirm at call sites.
//   offset0/1  : emitted as the offset0:/offset1: instruction fields through
//                the "B" (integer constant) constraint, so they must fold to
//                compile-time constants at the call site.
// No s_waitcnt is emitted here; the caller has to wait for the LDS read to
// complete before consuming reg_val.
template<typename DataType>
__forceinline__ __device__ void inline_ds_read2_b64(const int lds_offset, DataType& reg_val, const int offset0, const int offset1) {
    asm volatile(
        "ds_read2_b64 %0, %1, offset0:%2, offset1:%3\n"
        : "=v"(reg_val)
        : "v"(lds_offset), "B"(offset0), "B"(offset1)
        :);
}
template<typename dwordx2> template<typename dwordx2>
__forceinline__ __device__ void inlineasm_fa_ds_read2_b32(float* shared_addr, const int &lds_offset, dwordx2& reg_val, const int offset0, const int offset1) { __forceinline__ __device__ void inlineasm_fa_ds_read2_b32(float* shared_addr, const int &lds_offset, dwordx2& reg_val, const int offset0, const int offset1) {
int lds_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4; int lds_addr = reinterpret_cast<size_t>(shared_addr) + lds_offset * 4;
...@@ -364,14 +307,14 @@ template<typename VEC> ...@@ -364,14 +307,14 @@ template<typename VEC>
__forceinline__ __device__ void inlineasm_ds_read_b128(int lds_offset, VEC& data) { __forceinline__ __device__ void inlineasm_ds_read_b128(int lds_offset, VEC& data) {
asm volatile("ds_read_b128 %0, %1\n" asm volatile("ds_read_b128 %0, %1\n"
: "=v"(data) : "=v"(data)
: "s"(lds_offset) : "v"(lds_offset)
:); :);
} }
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inlineasm_ds_write_b128(int lds_offset, VEC& data) { __forceinline__ __device__ void inlineasm_ds_write_b128(int lds_offset, VEC& data) {
asm volatile("ds_write_b128 %0, %1\n" asm volatile("ds_write_b128 %0, %1\n"
:: "s"(lds_offset), "v"(data) :: "v"(lds_offset), "v"(data)
:); :);
} }
...@@ -385,7 +328,7 @@ __forceinline__ __device__ void inline_vgpr_init_zero(VEC &dst, const int idx) ...@@ -385,7 +328,7 @@ __forceinline__ __device__ void inline_vgpr_init_zero(VEC &dst, const int idx)
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) { __forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0" asm ("v_mov_b64 %0, 0x0"
: "=v"(dst) : "=v"(dst)
:); :);
...@@ -396,7 +339,7 @@ __forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) { ...@@ -396,7 +339,7 @@ __forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) {
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) { __forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t" asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t" "v_mov_b64 %1, 0x0\n\t"
: "=v"(dst.u64[0]), "=v"(dst.u64[1]) : "=v"(dst.u64[0]), "=v"(dst.u64[1])
...@@ -413,7 +356,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) { ...@@ -413,7 +356,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) {
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) { __forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t" asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t" "v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t" "v_mov_b64 %2, 0x0\n\t"
...@@ -463,7 +406,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) { ...@@ -463,7 +406,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) {
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) { __forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t" asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t" "v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t" "v_mov_b64 %2, 0x0\n\t"
...@@ -498,7 +441,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) { ...@@ -498,7 +441,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) {
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) { __forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t" asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t" "v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t" "v_mov_b64 %2, 0x0\n\t"
...@@ -514,7 +457,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) { ...@@ -514,7 +457,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) {
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void inline_vgpr4_init_zero_1x2x4(VEC s_reg[1][4]) { __forceinline__ __device__ void inline_vgpr4_init_zero_1x2x4(VEC s_reg[1][4]) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm ("v_mov_b64 %0, 0x0\n\t" asm ("v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t" "v_mov_b64 %1, 0x0\n\t"
"v_mov_b64 %2, 0x0\n\t" "v_mov_b64 %2, 0x0\n\t"
...@@ -570,43 +513,6 @@ inline __HOST_DEVICE__ unsigned short inlineasm_float2bfloat16_ushort_nonan(cons ...@@ -570,43 +513,6 @@ inline __HOST_DEVICE__ unsigned short inlineasm_float2bfloat16_ushort_nonan(cons
// DTK-compatible pk helpers (replace __builtin_hcu_pk_*_f32)
// Packed fp32 add: emits one v_pk_add_f32 with source operands (a, b) and
// returns the register pair the instruction writes.
inline __device__ __float2 hcu_pk_add_f32(__float2 a, __float2 b) {
    __float2 packed_sum;
    asm volatile("v_pk_add_f32 %0, %1, %2"
                 : "=v"(packed_sum)
                 : "v"(a), "v"(b));
    return packed_sum;
}
// Packed fp32 multiply: emits one v_pk_mul_f32 with source operands (a, b) and
// returns the register pair the instruction writes.
inline __device__ __float2 hcu_pk_mul_f32(__float2 a, __float2 b) {
    __float2 packed_product;
    asm volatile("v_pk_mul_f32 %0, %1, %2"
                 : "=v"(packed_product)
                 : "v"(a), "v"(b));
    return packed_product;
}
// Packed fp32 fused multiply-add: emits one v_pk_fma_f32 with source operands
// in the order (x, m, a) and returns the register pair the instruction writes.
// NOTE(review): presumably computes x * m + a per packed lane -- confirm
// against the ISA operand order.
inline __device__ __float2 hcu_pk_fma_f32(__float2 x, __float2 m, __float2 a) {
    __float2 packed_fma;
    asm volatile("v_pk_fma_f32 %0, %1, %2, %3"
                 : "=v"(packed_fma)
                 : "v"(x), "v"(m), "v"(a));
    return packed_fma;
}
// DTK requires these control operands to remain compile-time constants.
// Thin wrapper over __builtin_hcu_cvt_pk_f16_f32 for converting two floats.
//   Clamp, OModifier : template parameters, so they reach the builtin as
//                      compile-time constants. Only OModifier == 0 is accepted
//                      (enforced by the static_assert below).
//   src0, src1       : the two float inputs forwarded to the builtin.
// The literal 0 placed before each source is forwarded unchanged.
// NOTE(review): those zeros are presumably per-source modifier fields --
// confirm against the vendor builtin table.
// The return type is whatever the builtin returns (deduced via auto).
template<bool Clamp = false, int OModifier = 0>
inline __device__ auto hcu_cvt_pk_f16_f32(float src0, float src1) {
// Reject output modifiers that have not been validated.
static_assert(OModifier == 0, "Only o_modifier=0 is currently validated in HG DTK migration");
return __builtin_hcu_cvt_pk_f16_f32(0, src0, 0, src1, Clamp, OModifier);
}
// Thin wrapper over __builtin_hcu_cvt_pk_bf16_f32 for converting two floats.
//   Clamp      : template parameter, so it reaches the builtin as a
//                compile-time constant.
//   src0, src1 : the two float inputs forwarded to the builtin.
// The literal 0 placed before each source is forwarded unchanged.
// NOTE(review): those zeros are presumably per-source modifier fields --
// confirm against the vendor builtin table.
// The return type is whatever the builtin returns (deduced via auto).
template<bool Clamp = false>
inline __device__ auto hcu_cvt_pk_bf16_f32(float src0, float src1) {
return __builtin_hcu_cvt_pk_bf16_f32(0, src0, 0, src1, Clamp);
}
// Thin wrapper over __builtin_hcu_cvt_pk_f32_fp8 returning a vec2_fp32.
//   ByteSel : template parameter restricted to 0 or 2 by the static_assert
//             below; per its message it selects the low or high packed fp8
//             pair inside src0.
//   src0    : 32-bit word holding the packed fp8 inputs.
// The remaining builtin arguments (false, 0) are fixed literals.
// NOTE(review): their meaning is not visible here -- confirm against the
// vendor builtin table.
template<int ByteSel>
inline __device__ vec2_fp32 hcu_cvt_pk_f32_fp8(int src0) {
// Only the two pair-aligned selectors are accepted.
static_assert(ByteSel == 0 || ByteSel == 2, "ByteSel must select the low or high packed fp8 pair");
return __builtin_hcu_cvt_pk_f32_fp8(src0, false, 0, ByteSel);
}
// d = a * b + c // d = a * b + c
inline __device__ __float2 inlineasm_fa_v_pk_fma_f32(__float2 &a, const __float2& b, const __float2& c) { inline __device__ __float2 inlineasm_fa_v_pk_fma_f32(__float2 &a, const __float2& b, const __float2& c) {
__float2 d; __float2 d;
...@@ -637,7 +543,7 @@ inline __device__ void inlineasm_fa_v_pk_mul_f32(__float2 &dst, const __float2 & ...@@ -637,7 +543,7 @@ inline __device__ void inlineasm_fa_v_pk_mul_f32(__float2 &dst, const __float2 &
// c = a + b // c = a + b
inline __device__ void inline_v_pk_add_f32(__float2 &c, const __float2 &a, const __float2& b) { inline __device__ void inline_v_pk_add_f32(__float2 &c, const __float2 &a, const __float2& b) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
asm volatile("v_pk_add_f32 %0, %1, %2 ; inline_v_pk_add_f32" asm volatile("v_pk_add_f32 %0, %1, %2 ; inline_v_pk_add_f32"
: "=v"(c) : "=v"(c)
: "v"(a), "v"(b) : "v"(a), "v"(b)
...@@ -873,8 +779,8 @@ inline __host__ __device__ auto DownCastPair(const vec2_Element<FromType>& sourc ...@@ -873,8 +779,8 @@ inline __host__ __device__ auto DownCastPair(const vec2_Element<FromType>& sourc
template<> template<>
inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<float>& source) { inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = hcu_cvt_pk_f16_f32<false, 0>(source[0], source[1]); auto result = __builtin_hcu_cvt_pk_f16_f32(source[0], source[1], false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result); return *(vec2_Element<half_t>*)(&result);
#else #else
return __builtin_amdgcn_cvt_pkrtz(source[0], source[1]); return __builtin_amdgcn_cvt_pkrtz(source[0], source[1]);
...@@ -883,8 +789,8 @@ inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<f ...@@ -883,8 +789,8 @@ inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<f
template<> template<>
inline __host__ __device__ auto DownCastPair<float, bhalf_t>(const vec2_Element<float>& source) { inline __host__ __device__ auto DownCastPair<float, bhalf_t>(const vec2_Element<float>& source) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = hcu_cvt_pk_bf16_f32<false>(source[0], source[1]); auto result = __builtin_hcu_cvt_pk_bf16_f32(source[0], source[1], false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result); return *(vec2_Element<bhalf_t>*)(&result);
#else #else
vec2_Element<bhalf_t> result; vec2_Element<bhalf_t> result;
...@@ -903,8 +809,8 @@ inline __host__ __device__ auto DownCastPairNoPack(const FromType src0, const Fr ...@@ -903,8 +809,8 @@ inline __host__ __device__ auto DownCastPairNoPack(const FromType src0, const Fr
template<> template<>
inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float src0, const float src1) { inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float src0, const float src1) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = hcu_cvt_pk_f16_f32<false, 0>(src0, src1); auto result = __builtin_hcu_cvt_pk_f16_f32(src0, src1, false/*clamp*/, 0/*o_modifier*/);
return *(vec2_Element<half_t>*)(&result); return *(vec2_Element<half_t>*)(&result);
#else #else
return __builtin_amdgcn_cvt_pkrtz(src0, src1); return __builtin_amdgcn_cvt_pkrtz(src0, src1);
...@@ -913,8 +819,8 @@ inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float sr ...@@ -913,8 +819,8 @@ inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float sr
template<> template<>
inline __host__ __device__ auto DownCastPairNoPack<float, bhalf_t>(const float src0, const float src1) { inline __host__ __device__ auto DownCastPairNoPack<float, bhalf_t>(const float src0, const float src1) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
auto result = hcu_cvt_pk_bf16_f32<false>(src0, src1); auto result = __builtin_hcu_cvt_pk_bf16_f32(src0, src1, false/*clamp*/);
return *(vec2_Element<bhalf_t>*)(&result); return *(vec2_Element<bhalf_t>*)(&result);
#else #else
vec2_Element<bhalf_t> result; vec2_Element<bhalf_t> result;
...@@ -954,7 +860,7 @@ __host__ __device__ float splitkv_upcast_to_f32(const FromType &from_var) { ...@@ -954,7 +860,7 @@ __host__ __device__ float splitkv_upcast_to_f32(const FromType &from_var) {
template<typename output_dtype> template<typename output_dtype>
__forceinline__ __device__ void __builtin_hcu_cvt_pk4_fp8_f32(const vec4_fp32& source, int32_t &container) { __forceinline__ __device__ void __builtin_hcu_cvt_pk4_fp8_f32(const vec4_fp32& source, int32_t &container) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
if constexpr (std::is_same<output_dtype, fp8_e4m3>::value) { if constexpr (std::is_same<output_dtype, fp8_e4m3>::value) {
container = __builtin_hcu_cvt_pk_fp8_f32(source[0], source[1], container, false/*op_sel:[0,0,0,0]*/); container = __builtin_hcu_cvt_pk_fp8_f32(source[0], source[1], container, false/*op_sel:[0,0,0,0]*/);
container = __builtin_hcu_cvt_pk_fp8_f32(source[2], source[3], container, true/*op_sel:[0,0,0,1]*/); container = __builtin_hcu_cvt_pk_fp8_f32(source[2], source[3], container, true/*op_sel:[0,0,0,1]*/);
......
...@@ -7,35 +7,196 @@ ...@@ -7,35 +7,196 @@
#define VA_LIMIT_BITS(x) (0xffffffffffff & x) #define VA_LIMIT_BITS(x) (0xffffffffffff & x)
template<int INSTM, int INSTNM, int T, int R>
__forceinline__ __device__ void matrix_load_b16_lds_trans_builtin(size_t lds_addr_warp, vec4_int rsrc, int moffset) { #define MATRIX_LOAD_32X32_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__) #if defined(__gfx938__)
int soffset = lds_addr_warp + 0x80000000; int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
if constexpr (INSTM == 32 && INSTNM == 16) { matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
__builtin_hcu_matrix_load_32x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); VDATA:DST
} else if constexpr (INSTM == 32 && INSTNM == 32) { SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
__builtin_hcu_matrix_load_32x32_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); sgpr[SRSRC+2]: stride
} else if constexpr (INSTM == 64 && INSTNM == 16) { sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
__builtin_hcu_matrix_load_64x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); */
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,,);
}
#endif
}
#define MATRIX_LOAD_32X32_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X32_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS(lds_addr_per_wave, srsrc,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_addr_per_wave, srsrc,,);
}
#endif
}
// ======================================================= MLS32x16 ===========================================================
#define MATRIX_LOAD_32X16_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds_trans(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,,);
} }
(void)moffset;
#endif #endif
} }
template<int INSTM, int INSTNM, int T, int R>
__forceinline__ __device__ void matrix_load_b16_lds_builtin(size_t lds_addr_warp, vec4_int rsrc, int moffset) { #define MATRIX_LOAD_32X16_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X16_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds(DataType *shared_addr, vec4_uint srsrc, int &lds_offset, const int offset) {
#if defined(__gfx938__) #if defined(__gfx938__)
int soffset = lds_addr_warp + 0x00000000; int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
/*
if constexpr (INSTM == 32 && INSTNM == 16) { matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
__builtin_hcu_matrix_load_32x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); VDATA:DST
} else if constexpr (INSTM == 32 && INSTNM == 32) { SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
__builtin_hcu_matrix_load_32x32_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); sgpr[SRSRC+2]: stride
} else if constexpr (INSTM == 64 && INSTNM == 16) { sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
__builtin_hcu_matrix_load_64x16_b16(rsrc, (__attribute__((address_space(3))) short*)(soffset), 0, T, R, 0, 0); */
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS(lds_addr_per_wave, srsrc,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
if constexpr (r && t) {
MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r, t);
} else if constexpr (r && !t) {
MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc,, t);
} else {
MATRIX_LOAD_32X16_B16_LDS_GFX946(lds_addr_per_wave, srsrc,,);
} }
(void)moffset;
#endif #endif
} }
...@@ -62,6 +223,25 @@ __forceinline__ __device__ void matrix_load_b16_lds_builtin(size_t lds_addr_warp ...@@ -62,6 +223,25 @@ __forceinline__ __device__ void matrix_load_b16_lds_builtin(size_t lds_addr_warp
:); \ :); \
} }
#define DS_READ_MATRIX_32X32_B16_GFX946(OFFSET, REG, REG1, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_trans_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
}
#define DS_READ_MATRIX_32X16_B16(OFFSET, REG, TRANS) \ #define DS_READ_MATRIX_32X16_B16(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \ if constexpr (TRANS) { \
asm volatile( \ asm volatile( \
...@@ -141,15 +321,22 @@ __forceinline__ __device__ int inline_min_max(int source) { ...@@ -141,15 +321,22 @@ __forceinline__ __device__ int inline_min_max(int source) {
} }
// ======================================================= def =========================================================== // ======================================================= def ===========================================================
#define YY_USE_MPERMUTE
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac(VEC& data) { __forceinline__ __device__ void ds_mpermute_kdim_for_mmac(VEC& data) {
asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n":: "v"(data)); asm volatile(
"ds_mpermute_dwordx2 %0, %0 offset:6\n"
: "+v"(data)
: );
} }
template<typename VEC> template<typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac_wait(VEC& data) { __forceinline__ __device__ void ds_mpermute_kdim_for_mmac_wait(VEC& data) {
asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n\ts_waitcnt lgkmcnt(0)":: "v"(data)); asm volatile(
"ds_mpermute_dwordx2 %0, %0 offset:6\n\ts_waitcnt lgkmcnt(0)\n"
: "+v"(data)
: );
} }
...@@ -163,7 +350,7 @@ inline __device__ vec4_fp32 mmac_4interleave(const vec4_Element<T> &v1, const ve ...@@ -163,7 +350,7 @@ inline __device__ vec4_fp32 mmac_4interleave(const vec4_Element<T> &v1, const ve
template<> template<>
inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3) inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{ {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__)
return __builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(v1, v2, v3, 1, 0); return __builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(v1, v2, v3, 1, 0);
#else #else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3); return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
...@@ -173,7 +360,7 @@ inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1, ...@@ -173,7 +360,7 @@ inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1,
template<> template<>
inline __device__ vec4_fp32 mmac_4interleave<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3) inline __device__ vec4_fp32 mmac_4interleave<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{ {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__)
return __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(v1, v2, v3, 1, 0); return __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(v1, v2, v3, 1, 0);
#else #else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3); return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
......
#pragma once #pragma once
#include "numeric_types.h"
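// DTK: the second argument of __builtin_hcu_matrix_load_*_b8 is an addrspace(3) char* (the local clang reports a short* vs. char* mismatch); the b16 variants take short*, see intrinsic_mls_ds.h.
// For the change itself, soffset (the trans variant adds 0x80000000), the call pattern and verification, see section 4 of ROCM指令迁移到DTK.md in the repository root.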
// Inline asm with "s"(vec4_uint) can lower srsrc to VGPR and fail with invalid operand; builtins keep srsrc in the correct class.
template<int r, int t> #define MATRIX_LOAD_128X16_B8_LDS_TRANS(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
__forceinline__ __device__ void matrix_load_128x16_b8_lds_trans_builtin(size_t lds_addr_warp, vec4_int rsrc, int /*matrix_offset*/) { int soffset = LDSADDR + 0x80000000; \
#if defined(__gfx938__) asm volatile("s_nop 4\n\t" \
int soffset = static_cast<int>(lds_addr_warp) + 0x80000000; "matrix_load_128x16_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
// Third arg must be compile-time constant (same pattern as matrix_load_b16); call sites use matrix_offset==0. :: "s"(SRSRC), "s"(soffset), "n"(MATRIX_OFFSET) \
__builtin_hcu_matrix_load_128x16_b8( :);
rsrc,
(__attribute__((address_space(3))) char*)(soffset),
0,
t,
r,
0,
0);
#endif
}
template<int r, int t, class DataType> template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) { __forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
union union_vec4_uint u; int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
u.v32 = srsrc; if constexpr (r && t) {
size_t lds_addr_warp = reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset); MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r, t);
matrix_load_128x16_b8_lds_trans_builtin<r, t>(lds_addr_warp, u.i32, matrix_offset); } else if constexpr (r && !t) {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,, t);
} else {
MATRIX_LOAD_128X16_B8_LDS_TRANS(lds_addr_per_wave, srsrc, matrix_offset,,);
}
#endif #endif
} }
...@@ -50,28 +43,25 @@ __forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType ...@@ -50,28 +43,25 @@ __forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType
} }
template<int r, int t> #define MATRIX_LOAD_64x32_B8_LDS_REARRANGE(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
__forceinline__ __device__ void matrix_load_64x32_b8_lds_rearrange_builtin(size_t lds_addr_warp, vec4_int rsrc, int /*matrix_offset*/) { asm volatile("s_nop 4\n\t" \
#if defined(__gfx938__) "matrix_load_64x32_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
int soffset = static_cast<int>(lds_addr_warp); :: "s"(SRSRC), "s"(LDSADDR), "n"(MATRIX_OFFSET) \
__builtin_hcu_matrix_load_64x32_b8( :);
rsrc,
(__attribute__((address_space(3))) char*)(soffset),
0,
t,
r,
0,
0);
#endif
}
template<int r, int t, class DataType> template<int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_64x32_b8_lds_rearrange(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) { __forceinline__ __device__ void inline_matrix_load_64x32_b8_lds_rearrange(DataType *shared_addr, vec4_uint srsrc, int lds_offset, const int matrix_offset) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
union union_vec4_uint u; int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
u.v32 = srsrc; if constexpr (r && t) {
size_t lds_addr_warp = reinterpret_cast<size_t>(shared_addr) + static_cast<size_t>(lds_offset); MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset, r, t);
matrix_load_64x32_b8_lds_rearrange_builtin<r, t>(lds_addr_warp, u.i32, matrix_offset); } else if constexpr (r && !t) {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset, r,);
} else if constexpr (!r && t) {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset,, t);
} else {
MATRIX_LOAD_64x32_B8_LDS_REARRANGE(lds_addr_per_wave, srsrc, matrix_offset,,);
}
#endif #endif
} }
...@@ -119,4 +109,4 @@ template<class T, class AccumType> ...@@ -119,4 +109,4 @@ template<class T, class AccumType>
inline __device__ vec4_fp32 mmac_4interleave_b8(const vec8_Element<T> &v1, const vec8_Element<T> &v2, const vec4_fp32 &v3) inline __device__ vec4_fp32 mmac_4interleave_b8(const vec8_Element<T> &v1, const vec8_Element<T> &v2, const vec4_fp32 &v3)
{ {
return __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(v1, v2, v3, 1, 0); return __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(v1, v2, v3, 1, 0);
} }
\ No newline at end of file
...@@ -60,11 +60,6 @@ struct Flash_fwd_kernel_traits : public Base { ...@@ -60,11 +60,6 @@ struct Flash_fwd_kernel_traits : public Base {
static constexpr size_t k_smem_size = (STAGES * (kWaveN / 32) * (kBlockK / 32) * (32 * 34)) * sizeof(Element); static constexpr size_t k_smem_size = (STAGES * (kWaveN / 32) * (kBlockK / 32) * (32 * 34)) * sizeof(Element);
static constexpr size_t v_smem_size = (STAGES * kBlockK * 32/*WARP_K*/) * sizeof(Element); static constexpr size_t v_smem_size = (STAGES * kBlockK * 32/*WARP_K*/) * sizeof(Element);
#if (TARGET == 928)
static constexpr int kSmemSize = std::max(q_smem_size, v_smem_size) + k_smem_size * 2;
#else
static constexpr int kSmemSize = std::max(std::max(q_smem_size, v_smem_size), k_smem_size * 2);
#endif
}; };
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
......
#pragma once
#include "numeric_types.h"
namespace gfx92a {
// Loads the Q tile needed by the calling wave into the register array q_reg.
//
// Three code paths exist, selected at compile time:
//   * Is_Varlen, kHeadDim == 128 and WARP_NUM == 4:
//       rows are gathered one by one with inline_buffer_load_dwordx4_lds into
//       q_lds and then read back into q_reg with inline_ds_read2_b64.
//   * Is_Varlen, any other configuration:
//       every register is filled by two direct 8-byte reads from q_ptr.
//   * !Is_Varlen, kHeadDim == 128 and WARP_NUM == 4:
//       one inline_matrix_load_32x32_b16_lds_trans request into q_lds followed
//       by DS_READ_MATRIX_32X16_B16 / DS_READ_MATRIX_32X32_B16 into q_reg.
//   * !Is_Varlen, any other configuration: not implemented (see TODO below).
//
// Parameters:
//   q_ptr                global pointer to the Q data of this block
//   q_lds                LDS staging buffer (unused by the direct-read path)
//   q_reg                destination registers
//   warp_id              index of the calling wave
//   query_seqlen_stride  stride between sequence positions, in elements
//                        (it is added to element offsets of q_ptr)
//   query_ngroup_stride  stride between groups inside one sequence position,
//                        in elements
//   ngroups              number of rows that share one sequence position; a
//                        row index r maps to (r / ngroups, r % ngroups)
//   max_seq_q_offset     upper bound used to clamp / filter rows; presumably
//                        the number of valid Q rows -- TODO confirm with caller
//
// NOTE(review): in the Is_Varlen paths the row index is clamped with
// min(..., max_seq_q_offset - 1); with the default max_seq_q_offset = 0 this
// yields row -1 and a negative load offset. Callers presumably always pass a
// positive value -- confirm.
template<bool Is_Varlen, int kHeadDim, int kBlockK, int WARP_M, int WARP_NUM, int M_MMAC_COUNT, typename Element>
__forceinline__ __device__ void kvcache_prefetch_q_to_vgpr(
Element* q_ptr,
Element* q_lds,
union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
int warp_id,
int query_seqlen_stride,
int query_ngroup_stride,
int ngroups,
int max_seq_q_offset=0) {
constexpr int elementBytes = sizeof(Element);
// resource regs
auto q_addr = prepare_for_buffer_load<kHeadDim, Element, false>(q_ptr);
if constexpr (Is_Varlen) {
// lane index inside a group of 64 threads
int lane_id = int(threadIdx.x) & 63;
flash::wait_buffer_data_arrived<true>(0);
if constexpr (kHeadDim == 128 and WARP_NUM == 4) {
// global -> LDS: each lane requests 8 elements of one row;
// 4 lanes cover the 32 columns owned by this wave (warp_id * 32 ...)
for (int load = 0; load < M_MMAC_COUNT; ++load) {
// clamp to the last valid row so out-of-range lanes re-read it
int q_row = min(load * 16 + (lane_id >> 2), max_seq_q_offset - 1);
int q_col = warp_id * 32 + (lane_id & 3) * 8;
// split the flat row into (sequence position, group inside the position)
int q_row_seq = q_row / ngroups;
int q_row_regroup = q_row - q_row_seq * ngroups;
int q_load_offset = q_row_seq * ngroups * query_seqlen_stride + q_row_regroup * query_ngroup_stride + q_col;
// LDS slot of (load, warp_id): 16 rows x 32 columns, in elements
int q_lds_write_offset = (load * 4 + warp_id) * 16 * 32;
inline_buffer_load_dwordx4_lds<Element, 1>(q_lds, q_addr, q_lds_write_offset, 0, q_load_offset);
}
flash::wait_buffer_data_arrived<true>(0);
// LDS -> registers: every wave reads the slots written by all WARP_NUM waves
for (int load = 0; load < M_MMAC_COUNT; ++load) {
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
int q_lds_load_offset = (load * 4 + neighbor) * 16 * 32 + (lane_id & 15) * 32 + (lane_id >> 4) * 4;
// the LDS address is truncated to int before it is handed to the read helper
int q_lds_load_bytes = reinterpret_cast<size_t>(q_lds + q_lds_load_offset);
inline_ds_read2_b64(q_lds_load_bytes, q_reg[neighbor * 2 + load].f32, 0, 4);
}
}
flash::wait_lds_data_arrived<true>(0);
} else {
// generic varlen path: read straight from global memory, no LDS staging
#pragma unroll
for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
#pragma unroll
for (int load = 0; load < M_MMAC_COUNT; ++load) {
int q_row = min(load * 16 + (lane_id & 15), max_seq_q_offset - 1);
int q_col = neighbor * 32 + (lane_id >> 4) * 4;
int q_row_seq = q_row / ngroups;
int q_row_regroup = q_row - q_row_seq * ngroups;
int q_load_offset = q_row_seq * ngroups * query_seqlen_stride + q_row_regroup * query_ngroup_stride + q_col;
// two 8-byte chunks, 16 elements apart in the same row
q_reg[neighbor * 2 + load].data[0] = *(double*)(q_ptr + q_load_offset);
q_reg[neighbor * 2 + load].data[1] = *(double*)(q_ptr + q_load_offset + 16);
}
}
}
} else {
if constexpr (kHeadDim == 128 and WARP_NUM == 4) {
// prepare mls resource regs
vec4_uint q_srsrc;
q_srsrc[1] = q_addr[1];
q_srsrc[2] = query_seqlen_stride;
// global offset along seqlen_q
int q_loop = 0;
int q_seq_offset = q_loop * kBlockK;
// global offset along headdim
int q_dim_offset = warp_id * kBlockK;
// global bytes
// NOTE(review): only the low 32 address bits are advanced here, so a carry
// into q_srsrc[1] would be lost; the K/V prefetch helpers in this file do
// the addition on the full 64-bit value instead.
q_srsrc[0] = q_addr[0] + (q_seq_offset + q_dim_offset ) * elementBytes;
if constexpr (true) {
// row filter written to bits 8.. of word 3; skipped when the row count is a
// multiple of 32 (inline_min_max<0, 32> presumably clamps to [0, 32])
int nm_filter = inline_min_max<0, 32>(32 - max_seq_q_offset);
q_srsrc[3] = max_seq_q_offset % 32 == 0 ? 0: nm_filter << 8;
}
// compute lds write offset, each warp occupy 32 * 32 * sizeof(f16) = 2KB
int q_lds_write_offset = warp_id * (WARP_M / 32) * (kBlockK / 32) * (32 * 32);
int q_lds_offset_bytes = q_lds_write_offset * elementBytes;
// flash::wait_lds_data_arrived<true>(0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(q_lds, q_srsrc, q_lds_offset_bytes, 0);
// wait q data arrived
flash::wait_buffer_data_arrived<true>(0);
// lds -> vgprs
// The LDS offsets below are constants (tile index * 32 * 32 * 2); they do not
// include the address of q_lds.
if constexpr (M_MMAC_COUNT == 1) {
DS_READ_MATRIX_32X16_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, true);
DS_READ_MATRIX_32X16_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, true);
} else {
DS_READ_MATRIX_32X32_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, q_reg[0 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, q_reg[1 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, q_reg[2 * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, q_reg[3 * 2 + 1].f16, true);
}
flash::wait_lds_data_arrived<true>(0);
}
else {
// TODO
// NOTE(review): for any other configuration nothing is loaded and q_reg is
// left untouched.
}
}
}
/**
 * Issues the K prefetch of the calling wave: `prefetchKLevel` consecutive
 * head-dim tiles are requested from global memory into LDS with
 * inline_matrix_load_32x32_b16_lds_trans<0, 0>. No wait is performed here; the
 * function ends with a scheduling barrier only.
 *
 * @tparam kBlockK         extent of one tile along the head dimension, in elements
 * @tparam WARP_N          number of sequence rows owned by one wave
 * @tparam prefetchKLevel  number of head-dim tiles requested by this call
 * @tparam Element         element type; sizeof(Element) scales all byte offsets
 *
 * @param k_addr            resource words of K; words 0-1 are read as one 64-bit base address
 * @param k_lds             base pointer of the K staging area in LDS
 * @param warp_id           index of the calling wave; selects the rows and the LDS slot
 * @param k_seq_stride      stride between sequence rows, in elements (copied to
 *                          resource word 2 and scaled by sizeof(Element) for the row offset)
 * @param max_seq_k_offset  bound used for the row filter; presumably the number
 *                          of valid K rows -- TODO confirm with the caller
 */
template<int kBlockK, int WARP_N, int prefetchKLevel, typename Element>
__forceinline__ __device__ void kvcache_prefetch_k_to_lds(
    vec4_uint k_addr,
    Element* k_lds,
    int warp_id,
    int k_seq_stride,
    int max_seq_k_offset=0) {
    constexpr int elementBytes = sizeof(Element);

    // prepare mls resource regs
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_seq_stride;

    // Everything below depends only on warp_id / max_seq_k_offset, so it is
    // computed once instead of once per prefetched tile.
    //
    // A wave whose first row (warp_id * WARP_N) lies at or beyond
    // max_seq_k_offset is redirected to the rows of wave 0.
    const int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset;
    const int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
    // global bytes along seqlen
    const int k_seq_bytes = real_mls_warp_id * WARP_N * k_seq_stride * elementBytes;
    // Row filter, stored from bit 8 of resource word 3
    // (inline_min_max<0, 32> presumably clamps its argument to [0, 32]).
    const int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset);
    k_srsrc[3] = nm_filter << 8;
    // 64-bit base address taken from resource words 0-1
    const uint64_t k_base = *(uint64_t*)&k_addr;

    // occupy 4 * 2 * 2 * 32 * 32 * sizeof(f16) = 32 KB, in total
    #pragma unroll
    for (int prefetch_id = 0; prefetch_id < prefetchKLevel; ++prefetch_id) {
        // global bytes along headdim; the first tile of this call is tile 0
        const int k_dim_bytes = prefetch_id * kBlockK * elementBytes;

        // acquire buffer address: overwrite resource words 0-1 with the
        // address of this tile, limited by VA_LIMIT_BITS
        *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(k_base + k_dim_bytes + k_seq_bytes);

        // compute lds offset / bytes: slot (warp_id, prefetch_id)
        int k_lds_stage_offset = (warp_id * prefetchKLevel + prefetch_id) * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);
        int lds_offset_bytes = k_lds_stage_offset * elementBytes;
        inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
    }
    __builtin_amdgcn_sched_barrier(0);
}
/**
 * Issues the V prefetch of the calling wave: `prefetchVLevel` consecutive
 * head-dim tiles are requested from global memory into LDS with
 * inline_matrix_load_32x32_b16_lds<0, 1>. Only prefetchVLevel == 2 and
 * prefetchVLevel == 4 issue loads; for any other value the function just sets
 * up the descriptor and emits the scheduling barrier. No wait is performed here.
 *
 * @tparam kHeadDim        not used by this function
 * @tparam kBlockN         extent of one tile along the head dimension, in elements
 * @tparam WARP_K          number of sequence rows owned by one wave
 * @tparam STAGES          number of LDS stages; only affects the slot stride of
 *                         the two-level layout
 * @tparam prefetchVLevel  number of head-dim tiles requested by this call
 * @tparam Element         element type of the LDS pointer; byte offsets use a
 *                         fixed element size of 2 bytes
 *
 * @param v_addr             resource words of V; words 0-1 are read as one 64-bit base address
 * @param v_lds              base pointer of the V staging area in LDS
 * @param warp_id            index of the calling wave; selects the rows and the LDS slot
 * @param v_seq_stride       stride between sequence rows, in elements
 * @param max_seq_kv_offset  bound used for the row filter; presumably the
 *                           number of valid K/V rows -- TODO confirm with the caller
 */
template<int kHeadDim, int kBlockN, int WARP_K, int STAGES, int prefetchVLevel, typename Element>
__forceinline__ __device__ void kvcache_prefetch_v_to_lds(
    vec4_uint v_addr,
    Element* v_lds,
    int warp_id,
    int v_seq_stride,
    int max_seq_kv_offset=0) {
    constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
    constexpr int elementBytes = 2;

    // prepare mls resource regs
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = v_seq_stride;

    if constexpr (prefetchVLevel == 2 || prefetchVLevel == 4) {
        // LDS slots reserved per wave. The two-level layout keeps STAGES
        // copies of its slots and this call fills stage 0; the four-level
        // layout has a single copy.
        constexpr int kSlotsPerWarp = (prefetchVLevel == 2) ? STAGES * prefetchVLevel : prefetchVLevel;

        // Everything below depends only on warp_id / max_seq_kv_offset, so it
        // is computed once instead of once per prefetched tile.
        //
        // A wave whose first row (warp_id * WARP_K) lies at or beyond
        // max_seq_kv_offset is redirected to the rows of wave 0.
        const int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset;
        const int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
        // global bytes along seq dimension
        const int v_seq_bytes = real_mls_warp_id * WARP_K * v_seq_stride * elementBytes;
        // Row filter from bit 8 of resource word 3, skipped when the bound is
        // a multiple of kBlockN (inline_min_max<0, 32> presumably clamps to [0, 32]).
        const int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset);
        v_srsrc[3] = (max_seq_kv_offset % kBlockN == 0) ? 0 : (nm_filter << 8);
        // Bit 17 of word 3 is always set; its meaning is not visible in this
        // file -- NOTE(review): confirm against the matrix_load descriptor layout.
        v_srsrc[3] += 0x20000;
        // 64-bit base address taken from resource words 0-1
        const uint64_t v_base = *(uint64_t*)&v_addr;

        #pragma unroll
        for (int prefetch_id = 0; prefetch_id < prefetchVLevel; ++prefetch_id) {
            // global bytes along headdim dimension; the first tile of this call is tile 0
            const int v_dim_bytes = prefetch_id * kBlockN * elementBytes;

            // overwrite resource words 0-1 with the address of this tile
            *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(v_base + v_seq_bytes + v_dim_bytes);

            // lds bytes: slot (warp_id, prefetch_id)
            int v_lds_write_offset = (warp_id * kSlotsPerWarp + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32);
            int v_lds_write_bytes = v_lds_write_offset * elementBytes;
            inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_lds_write_bytes, 0);
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
/**
 * Computes the Q*K score tile of the calling wave.
 *
 * The accumulators s_reg[0][...] are zeroed, then for each of the
 * kHeadDim / kBlockK head-dim tiles the K tile of this wave is read from LDS
 * into registers with DS_READ_MATRIX_32X32_B16 and multiplied with the matching
 * Q registers through mmac_4interleave. The function finishes with
 * flash::wait_lds_data_arrived<true>(0).
 *
 * Only WARP_M = WARP_N = kBlockK = 32 and prefetchKLevel = 4 are accepted
 * (enforced by static_assert).
 *
 * The V prefetch that gives the function its name is currently commented out,
 * so k_addr, v_addr, v_lds, k_seq_stride, v_seq_stride and max_seq_kv_offset
 * are not used.
 *
 * @param k_lds    base pointer of the K staging area in LDS
 * @param q_reg    Q registers; tile (k_loop, min_tile_m) is q_reg[k_loop * 2 + min_tile_m]
 * @param s_reg    score accumulators; only row 0 is written, entry
 *                 [min_tile_n * 2 + min_tile_m]
 * @param warp_id  index of the calling wave; selects its K slots in LDS
 */
template<int kHeadDim, int kHeadDimV, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int prefetchKLevel, int prefetchVLevel, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v(
    vec4_uint k_addr,
    vec4_uint v_addr,
    Element* k_lds,
    Element* v_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
    int warp_id,
    int k_seq_stride,
    int v_seq_stride,
    int max_seq_kv_offset=0) {
    // The message is passed as the second argument so that the compiler prints
    // it when the assertion fails.
    static_assert(WARP_M == 32 and WARP_N == 32 and kBlockK == 32, "To simplify, only WARP_M = WARP_N = kBlockK = 32 is supported!");
    static_assert(prefetchKLevel == 4, "To simplify, only prefetchKLevel = 4 is supported");
    constexpr int K_LOAD_REQUESTS = (WARP_N / 32) * (kBlockK / 32);
    constexpr int elementBytes = 2;

    // alloc k_regs, 32x32 f16 per warp, and thus 16 f16 for each threads
    union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];

    // s_reg initialize
    #pragma unroll
    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            s_reg[0][min_tile_n * 2 + min_tile_m].b64[0] = __builtin_hcu_mov_b64(0x0);
            s_reg[0][min_tile_n * 2 + min_tile_m].b64[1] = __builtin_hcu_mov_b64(0x0);
        }
    }

    // qk gemm main loop, along kheaddim dimension
    for (int k_loop = 0; k_loop < (kHeadDim / kBlockK); k_loop += 1) {
        // The argument counts down from (number of tiles - 1) to 0 as tiles are
        // consumed; presumably the number of K loads allowed to stay in flight.
        flash::wait_buffer_data_arrived<false>((kHeadDim / kBlockK) - 1 - k_loop);

        // lds -> vgprs: slot (warp_id, k_loop), as a byte address truncated to int
        int k_lds_load_bytes = reinterpret_cast<size_t>(k_lds) + (warp_id * prefetchKLevel + k_loop) * K_LOAD_REQUESTS * (32 * 32) * elementBytes;
        DS_READ_MATRIX_32X32_B16(k_lds_load_bytes, k_reg[0].f16, k_reg[1].f16, true);

        // mmac flow
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            // k_reg[0] is consumed first (argument 1), then k_reg[1] (argument 0)
            flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_n);
            #pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    int q_tile_id = k_loop * 2 + min_tile_m;
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[q_tile_id].f16x4[min_tile_k],
                        k_reg[min_tile_n].f16x4[min_tile_k],
                        s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
    }

    // need to reduce results on scores_max and prefetch V, and thus sync
    // can be simplified as flash::wait_all_warp_arrived()
    flash::wait_lds_data_arrived<true>(0);

    // prefetch v
    // can be rearranged while qk doing mmac
    // gfx92a::kvcache_prefetch_v_to_lds<kHeadDimV, kBlockK, kBlockK, 2/*STAGES*/, prefetchVLevel, Element>(v_addr, v_lds, warp_id, v_seq_stride, max_seq_kv_offset);
}
// Masks out score columns at or beyond max_seqlen_k.
//
// The columns owned by a lane are
//   col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4
// with lane_id = threadIdx.x & 63. Every accumulator element whose column is
// >= max_seqlen_k is overwritten with -INFINITY for all (mi, min_tile_m) row tiles;
// all other elements are left untouched.
template <int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, typename DataType>
__forceinline__ __device__ void kvcache_apply_mask(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int max_seqlen_k, const int col_idx_offset_= 0) {
    // First column handled by this lane.
    const int lane_col = col_idx_offset_ + ((threadIdx.x & 63) >> 4);
    #pragma unroll
    for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
        #pragma unroll
        for (int half_n = 0; half_n < 2; ++half_n) {
            #pragma unroll
            for (int v = 0; v < 4; ++v) {
                const int col = lane_col + ni * 32 + half_n * 16 + v * 4;
                // In-range columns keep their score.
                if (col < max_seqlen_k) { continue; }
                #pragma unroll
                for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
                    #pragma unroll
                    for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
                        tensor[mi + ni * M_WARP_COUNT][half_n * 2 + half_m].f32[v] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Causal masking of the score accumulators.
//
// For every row owned by the lane a right-hand column limit is derived and each
// accumulator element whose column index is strictly greater than that limit is
// replaced with -INFINITY:
//   Is_Varlen : limit = min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups)
//   otherwise : limit = min(max_seqlen_k, r + max_seqlen_k - mtp), where
//               r = row % mtp when layout == 0 and r = row / ngroups for any other layout.
// Rows are row_idx_offset_ + (lane_id & 15) + mi * 32 + min_tile_m * 16 and columns are
// col_idx_offset_ + (lane_id >> 4) + ni * 32 + min_tile_n * 16 + vec_idx * 4,
// with lane_id = threadIdx.x & 63.
template <int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, bool Is_Varlen, typename DataType>
__forceinline__ __device__ void kvcache_apply_mask_causal(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
                                                          const int max_seqlen_k, const int row_idx_offset_,
                                                          const int max_seqlen_q, const int ngroups, const int mtp, const int layout) {
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15);
    const int lane_col = col_idx_offset_ + (lane >> 4);
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        #pragma unroll
        for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
            const int row = lane_row + mi * 32 + half_m * 16;
            // Last column this row is allowed to keep.
            int last_visible_col;
            if constexpr (Is_Varlen) {
                // the row / ngroups mapping is only valid for layout 1 (bshd)
                last_visible_col = std::min(max_seqlen_k, (row / ngroups) + max_seqlen_k - (max_seqlen_q / ngroups));
            } else {
                const int row_in_mtp = layout == 0 ? (row % mtp) : (row / ngroups);
                last_visible_col = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
            }
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int half_n = 0; half_n < 2; ++half_n) {
                    auto& acc = tensor[mi + ni * M_WARP_COUNT][half_n * 2 + half_m];
                    const int first_col = lane_col + ni * 32 + half_n * 16;
                    #pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        if (first_col + v * 4 > last_visible_col) {
                            acc.f32[v] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Down-converts the fp32 score accumulators in s_reg[0] to packed 16-bit pairs in
// p_reg[0]. For every used MMAC slot (min_tile_n * 2 + min_tile_m, min_tile_m <
// M_MMAC_COUNT) both f32x2 halves are converted with DownCastPair<float, Element>.
// ElementAccum is part of the template signature but is not used by the body.
template <int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void convert_attn_f32_to_f16(union_vec4_fp32 s_reg[1][4], union_vec2_f16x2<Element> p_reg[1][4]) {
    #pragma unroll
    for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
        #pragma unroll
        for (int half_n = 0; half_n < 2; ++half_n) {
            const int mmac_id = half_n * 2 + half_m;
            #pragma unroll
            for (int pair = 0; pair < 2; ++pair) {
                p_reg[0][mmac_id].f16x2[pair] = DownCastPair<float, Element>(s_reg[0][mmac_id].f32x2[pair]);
            }
        }
    }
}
// Warp-level PV GEMM: multiplies the P fragments in p_reg with V tiles read from LDS and
// accumulates the products into pv_reg. The code path is selected at compile time by
// prefetchVLevel:
//   * prefetchVLevel == 2: this function issues the V loads into LDS itself
//     (inline_matrix_load_32x32_b16_lds), prefetchVLevel tiles per main-loop iteration,
//     alternating between two LDS stages when STAGES == 2.
//   * prefetchVLevel == 4: no V load is issued here; the body only waits on outstanding
//     buffer loads and then consumes K_LOOP_COUNT tiles from v_lds. When prefetchK is set
//     and max_seq_kv_offset > kBlockK it first issues a K prefetch through
//     gfx92a::kvcache_prefetch_k_to_lds.
//   * any other value: only the synchronizing waits at entry and exit remain.
//
// Parameters:
//   v_addr             V descriptor; its low 64 bits are used as the base address and
//                      element [1] is copied into the working descriptor.
//   k_addr, k_lds, k_seq_stride
//                      only forwarded to the K prefetch in the prefetchVLevel == 4 path.
//   v_lds              LDS staging area for V tiles.
//   p_reg              P fragments; permuted in place by ds_mpermute_kdim_for_mmac.
//   pv_reg             output accumulators, indexed by tile id and MMAC slot.
//   warp_id            selects the warp's LDS region (and, in the == 2 path, its V rows).
//   v_seq_stride       written to v_srsrc[2] and used for the per-warp global offset.
//   max_seq_kv_offset  feeds the nm_filter computation and the can_prefetch_k test.
//
// NOTE(review): with STAGES == 2 the main loop starts at n_loop == 2 and computes tiles
// n_loop - 2 / n_loop - 1 from the stage that is not being written, so the first two
// V tiles presumably must already have been loaded into stage 0 by the caller -- confirm.
template<bool prefetchK, int K_LOOP_COUNT, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int prefetchKLevel, int prefetchVLevel, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_pv_gemm_prefetch_k(
vec4_uint v_addr,
vec4_uint k_addr,
Element* v_lds,
Element* k_lds,
union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
int warp_id,
int v_seq_stride,
int k_seq_stride,
int max_seq_kv_offset=0) {
constexpr int WARP_K = PV_K_WARP_COUNT * 32;
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32");
static_assert (M_WARP_COUNT == 1, "for gfx938, only WARP_M = 32 is supported yet!");
static_assert (PV_N_WARP_COUNT == 1, "for gfx938, only WARP_N = 32 is supported yet!");
static_assert (PV_K_WARP_COUNT == 1, "for gfx938, only WARP_K = 32 is supported yet!");
constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
constexpr int elementBytes = 2;
// sync lds usage for reducing max/sum
flash::wait_lds_data_arrived<true>(0); // __syncthreads();
if constexpr (prefetchVLevel == 2) {
// hold v regs
union_vec4_f16x2<Element> v_reg[1 * PV_K_WARP_COUNT * PV_N_WARP_COUNT * 2];
// prepare v resource regs
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_seq_stride;
// pingpong stage: index of the LDS stage the next loads are written to
int stage_id = (STAGES == 2) ? 1: 0;
// make p 4-interleave layout for pv gemm
// strange: delete wait, results are wrong even if flash::wait_lds_data_arrived<false>(0);
ds_mpermute_kdim_for_mmac(p_reg[0][0].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][1].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][2].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][3].f16x4);
// pv gemm main loop
constexpr int N_LOOP_STEP = (STAGES == 2) ? prefetchVLevel: 1;
constexpr int N_LOOP_START = (STAGES == 2) ? N_LOOP_STEP: 1;
for (int n_loop = N_LOOP_START; n_loop < K_LOOP_COUNT; n_loop += N_LOOP_STEP) {
// issue the loads for tiles n_loop .. n_loop + prefetchVLevel - 1 into stage stage_id
#pragma unroll
for (int prefetch_id = 0; prefetch_id < prefetchVLevel; ++prefetch_id) {
// global bytes along headdim dimension
int v_dim_bytes = (n_loop + prefetch_id) * kBlockN * elementBytes;
// global bytes along seq dimension
int v_seq_bytes;
if constexpr (true) {
int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset;
// when nm_filter_max >= 32 the warp falls back to the rows of warp 0
int real_mls_warp_id = nm_filter_max >= 32 ? 0: warp_id; // can be simplified after gfx938
v_seq_bytes = real_mls_warp_id * WARP_K * v_seq_stride * elementBytes;
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset);
// the filter is only applied when max_seq_kv_offset is not a multiple of kBlockN
v_srsrc[3] = max_seq_kv_offset % kBlockN == 0 ? 0: nm_filter << 8;
v_srsrc[3] += 0x20000;
}
// overwrite the low 64 bits of the descriptor with the tile's address
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_seq_bytes + v_dim_bytes);
// lds write bytes
int v_lds_write_offset = (warp_id * STAGES * prefetchVLevel + stage_id * prefetchVLevel + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32);
int v_lds_write_bytes = v_lds_write_offset * elementBytes;
inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_lds_write_bytes, 0);
}
// wait v data stored in lds
if constexpr (N_LOOP_STEP == 2) {
flash::wait_buffer_data_arrived<false>((prefetchVLevel + prefetchVLevel - 1) * V_LOAD_REQUESTS);
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 2) {
flash::wait_buffer_data_arrived<false>(1 * V_LOAD_REQUESTS);
} else if constexpr (N_LOOP_STEP == 1 and STAGES == 1) {
flash::wait_buffer_data_arrived<false>(0);
}
// roll stage: after the flip, stage_id names the stage that is read below
if constexpr (STAGES == 2) { stage_id ^= 1; }
// lds -> vgprs
int v_lds_load_bytes = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * prefetchVLevel + stage_id * prefetchVLevel + 0) * (V_LOAD_REQUESTS * 32 * 32) * elementBytes;
DS_READ_MATRIX_32X32_B16_ALT2(v_lds_load_bytes, v_reg[0].f16, v_reg[1].f16, false);
// pv mmac flow
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_k);
// with two stages the tile computed here trails the tile being loaded by two
int pv_tile_id = (STAGES == 2) ? n_loop - 2: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
// process second tile of pv gemm
// (this condition is always true inside the enclosing prefetchVLevel == 2 branch)
if constexpr (prefetchVLevel == 2) {
flash::wait_buffer_data_arrived<false>(prefetchVLevel * V_LOAD_REQUESTS);
// lds -> vgprs
int v_lds_load_bytes = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * prefetchVLevel + stage_id * prefetchVLevel + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * elementBytes;
DS_READ_MATRIX_32X32_B16_ALT2(v_lds_load_bytes, v_reg[0].f16, v_reg[1].f16, false);
// pv gemm mmac flow
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_k);
int pv_tile_id = (STAGES == 2) ? n_loop - 1: n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
// pipeline drain: consume the tiles loaded by the last main-loop iteration
if constexpr (STAGES == 2) {
int n_loop = K_LOOP_COUNT;
// wait v stored in lds
flash::wait_buffer_data_arrived<false>((prefetchVLevel - 1) * V_LOAD_REQUESTS);
// roll stage
stage_id ^= 1;
// lds -> vgprs
int v_lds_load_bytes = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * prefetchVLevel + stage_id * prefetchVLevel) * (V_LOAD_REQUESTS * 32 * 32) * elementBytes;
DS_READ_MATRIX_32X32_B16_ALT2(v_lds_load_bytes, v_reg[0].f16, v_reg[1].f16, false);
// pv gemm mmac flow
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_k);
int pv_tile_id = n_loop - 2;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
// process second tile of pv gemm
if constexpr (N_LOOP_STEP == 2) {
flash::wait_buffer_data_arrived<false>(0);
// lds -> vgprs
int v_lds_load_bytes = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * prefetchVLevel + stage_id * prefetchVLevel + 1/*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * elementBytes;
DS_READ_MATRIX_32X32_B16_ALT2(v_lds_load_bytes, v_reg[0].f16, v_reg[1].f16, false);
// pv gemm mmac flow
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_k);
int pv_tile_id = n_loop - 1;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
} else if constexpr (prefetchVLevel == 4) {
// K for the next block is only prefetched when more than kBlockK of offset remains
bool can_prefetch_k = max_seq_kv_offset > kBlockK;
if constexpr (prefetchK) {
if (can_prefetch_k) {
gfx92a::kvcache_prefetch_k_to_lds<kBlockN, PV_N_WARP_COUNT * 32, prefetchKLevel, Element>(k_addr, k_lds, warp_id, k_seq_stride, max_seq_kv_offset - kBlockK);
}
}
// hold v regs
union_vec4_f16x2<Element> v_reg[1 * PV_K_WARP_COUNT * PV_N_WARP_COUNT * 2];
// prepare v resource regs
// NOTE(review): v_srsrc is initialized but not used again in this branch
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_seq_stride;
// make p 4-interleave layout for pv gemm
// strange: delete wait, results are wrong even if flash::wait_lds_data_arrived<false>(0);
ds_mpermute_kdim_for_mmac(p_reg[0][0].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][1].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][2].f16x4);
ds_mpermute_kdim_for_mmac(p_reg[0][3].f16x4);
// wait v data stored in lds
// if a K prefetch was just issued, its prefetchKLevel loads are allowed to stay in flight
if constexpr (prefetchK) {
if (can_prefetch_k) {
flash::wait_buffer_data_arrived<false>(prefetchKLevel/*4 for hdim 128*/);
} else {
flash::wait_buffer_data_arrived<false>(0);
}
} else {
flash::wait_buffer_data_arrived<false>(0);
}
// pv gemm main loop
for (int n_loop = 0; n_loop < K_LOOP_COUNT; n_loop += 1) {
// lds -> vgprs
int v_lds_load_bytes = reinterpret_cast<size_t>(v_lds) + (warp_id * prefetchVLevel + n_loop) * (V_LOAD_REQUESTS * 32 * 32) * elementBytes;
DS_READ_MATRIX_32X32_B16_ALT2(v_lds_load_bytes, v_reg[0].f16, v_reg[1].f16, false);
// pv mmac flow
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_k);
int pv_tile_id = n_loop;
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
v_reg[min_tile_k].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
}
}
}
}
}
// sync lds usage
flash::wait_lds_data_arrived<true>(0);
}
// Cross-warp reduction of the output accumulators through the acc_o_lds buffer.
//
// Implemented only for K_LOOP_COUNT == 4, WARP_NUM == 4 and K_WARP_COUNT == 1; every
// other configuration compiles to an empty body. For each min_tile_n half:
//   1. every warp stores its K_LOOP_COUNT tiles of that half into its own region of
//      acc_o_lds (region stride = WARP_NUM 32x32 tiles, 4 values per lane per MMAC);
//   2. after __syncthreads(), warp w reads tile index w from the regions of warps
//      0..3 and sums them (warp 0 first, then 1, 2, 3) into acc_o[0];
//   3. a second __syncthreads() protects the buffer before the next half reuses it.
// seqlen_q is accepted but not read by this body.
template<int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
    vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum* acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id) {
    if constexpr (K_LOOP_COUNT == 4 and WARP_NUM == 4 and K_WARP_COUNT == 1) {
        constexpr int kLaneValues = 4;                          // values per lane per MMAC tile
        constexpr int kTile16x32Values = 64 * kLaneValues;      // one 16x32 tile across 64 lanes
        constexpr int kTile32x32Values = 2 * kTile16x32Values;
        constexpr int kWarpRegionValues = WARP_NUM * kTile32x32Values;
        #pragma unroll
        for (int half_n = 0; half_n < 2; ++half_n) {
            // Step 1: publish this warp's tiles for the current half.
            #pragma unroll
            for (int tile = 0; tile < K_LOOP_COUNT; ++tile) {
                #pragma unroll
                for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
                    const int store_offset = warp_id * kWarpRegionValues + tile * kTile32x32Values + half_m * kTile16x32Values + lane_id * kLaneValues;
                    *(vec4_fp32*)(acc_o_lds + store_offset) = acc_o[tile][half_n * 2 + half_m].f32;
                }
            }
            __syncthreads();
            // Step 2: gather tile `warp_id` from every warp's region and accumulate.
            #pragma unroll
            for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
                auto& total = acc_o[0][half_n * 2 + half_m];
                ElementAccum* gather_base = acc_o_lds + 0 * kWarpRegionValues + warp_id/*h_idx*/ * kTile32x32Values + half_m * kTile16x32Values + lane_id * kLaneValues;
                // warp 0's contribution seeds the sum
                total.f32 = *(vec4_fp32*)(gather_base + 0 * kWarpRegionValues);
                // warps 1, 2, 3 are added in ascending order
                #pragma unroll
                for (int peer = 1; peer < 4; ++peer) {
                    auto contribution = *(union_vec4_fp32*)(gather_base + peer * kWarpRegionValues);
                    #pragma unroll
                    for (int vec_id = 0; vec_id < 2; ++vec_id) {
                        total.u64[vec_id] = __builtin_hcu_pk_add_f32(total.u64[vec_id], contribution.u64[vec_id]);
                    }
                }
            }
            // Step 3: nobody may overwrite the buffer until all reads are done.
            __syncthreads();
        }
    } else {
        // To be implemented
    }
}
// Epilogue: converts the accumulator tile acc_o[0] to SplitkvAccumType and stores it to
// global memory at params.oaccum_ptr (Split) or params.o_ptr (otherwise), offset by
// row_offset_o elements.
//
// Implemented only for K_LOOP_COUNT == 4 and WARP_NUM == 4; other configurations compile
// to a body that only computes the local indices. In the implemented case the k_loop loop
// advances by WARP_NUM and therefore runs exactly once, always reading acc_o[0].
// NOTE(review): presumably acc_o[0] holds this warp's already-reduced tile (see
// kvcache_acco_reduce_tile16x32) -- confirm against the caller.
//
// Addressing: a lane handles row (lane_id & 15) + min_tile_m * 16 and writes one
// vec4_fp32 (16 bytes) at column (warp_id * kBlockK + (lane_id >> 4) * 8). Rows with
// index >= seqlen_q_limit are not stored. With Is_Varlen the row index is split into
// (row / params.ngroups, row % params.ngroups) and the group part is advanced by
// params.o_head_stride.
template<bool Is_Varlen, bool Split, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_MMAC_COUNT, typename SplitkvAccumType, typename ElementAccum, typename Params>
__forceinline__ __device__ void kvcache_varlen_epilogue_store_output(
vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT][4],
Params params,
int64_t row_offset_o,
int seqlen_q_limit,
int warp_id,
int lane_id) {
int o_mmac_row = lane_id & 15;
int o_mmac_col = lane_id >> 4;
int o_seq_stride = params.o_row_stride;
SplitkvAccumType* o_ptr = reinterpret_cast<SplitkvAccumType *>(Split ? params.oaccum_ptr: params.o_ptr) + row_offset_o;
if constexpr (K_LOOP_COUNT == 4 and WARP_NUM == 4) {
// each warp outputs several tiles separately
#pragma unroll
for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += WARP_NUM/*1*/) {
int tile_32x32_id = 0/*k_loop*/;
union_vec4_f16x2<SplitkvAccumType> o_data[M_MMAC_COUNT];
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// 2-interleave: pack the four values of each of the two MMAC slots (n = 0, 1) pairwise
o_data[min_tile_m].f16x2[0 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[0], acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[1]);
o_data[min_tile_m].f16x2[1 + 0 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[2], acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[3]);
o_data[min_tile_m].f16x2[0 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[0], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[1]);
o_data[min_tile_m].f16x2[1 + 1 * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[2], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[3]);
// make 4-interleave
ds_mpermute_kdim_for_mmac(o_data[min_tile_m].f16x4[0]);
ds_mpermute_kdim_for_mmac(o_data[min_tile_m].f16x4[1]);
}
union_vec4_f16x2<SplitkvAccumType> o_dwordx4[M_MMAC_COUNT];
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// two permutes were issued per min_tile_m above; the count passed here shrinks by two
// per iteration (presumably the number of permutes still allowed to be pending)
flash::wait_lds_data_arrived<false>((M_MMAC_COUNT - 1 - min_tile_m) * 2);
// interleave elements 0..3 with elements 4..7: output order is 0,4,1,5,2,6,3,7
o_dwordx4[min_tile_m].f16[0] = o_data[min_tile_m].f16[0];
o_dwordx4[min_tile_m].f16[1] = o_data[min_tile_m].f16[4];
o_dwordx4[min_tile_m].f16[2] = o_data[min_tile_m].f16[1];
o_dwordx4[min_tile_m].f16[3] = o_data[min_tile_m].f16[5];
o_dwordx4[min_tile_m].f16[4] = o_data[min_tile_m].f16[2];
o_dwordx4[min_tile_m].f16[5] = o_data[min_tile_m].f16[6];
o_dwordx4[min_tile_m].f16[6] = o_data[min_tile_m].f16[3];
o_dwordx4[min_tile_m].f16[7] = o_data[min_tile_m].f16[7];
}
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// store 4 dwords into global memory
int seqlen_q_idx = o_mmac_row + min_tile_m * 16;
// rows at or beyond seqlen_q_limit are skipped
if (seqlen_q_idx < seqlen_q_limit) {
// element offset from o_ptr
int pv_global_addr;
if constexpr (Is_Varlen) {
int true_seqlen_q = seqlen_q_idx / params.ngroups;
int true_group_id = seqlen_q_idx % params.ngroups;
pv_global_addr = true_seqlen_q * params.ngroups * o_seq_stride + true_group_id * params.o_head_stride + (warp_id + 0) * kBlockK + o_mmac_col * 8;
} else {
pv_global_addr = seqlen_q_idx * o_seq_stride + (warp_id + 0) * kBlockK + o_mmac_col * 8;
}
*(vec4_fp32*)(o_ptr + pv_global_addr) = o_dwordx4[min_tile_m].f32;
}
}
}
} else {
// To be implemented
}
}
} // end of namespace gfx92a
\ No newline at end of file
#pragma once #pragma once
#include "numeric_types.h" #include "numeric_types.h"
#include "intrinsic.h"
// Reads the attention-sink value for head_idx from s_aux_ptr and returns it as float.
// s_aux_type selects how the buffer is interpreted:
//   1     -> array of float
//   2     -> array of half_t
//   other -> array of BFloat16
__forceinline__ __device__ float fp8_kvcache_attention_sink_load(const void *s_aux_ptr, int s_aux_type, int head_idx) {
    switch (s_aux_type) {
        case 1: {
            const float *as_f32 = reinterpret_cast<const float *>(s_aux_ptr);
            return as_f32[head_idx];
        }
        case 2: {
            const half_t *as_f16 = reinterpret_cast<const half_t *>(s_aux_ptr);
            return UpCast<half_t, float>(as_f16[head_idx]);
        }
        default: {
            const BFloat16 *as_bf16 = reinterpret_cast<const BFloat16 *>(s_aux_ptr);
            return UpCast<BFloat16, float>(as_bf16[head_idx]);
        }
    }
}
// Folds a per-head attention-sink value into the running softmax statistics.
//
// For every row owned by the lane (row = m_block * kBlockM + mi * 32 + (lane_id & 15)
// + min_tile_m * 16) the sink value of head (bidh * ngroups + row % ngroups) is loaded
// and then:
//   new_max    = max(scores_max * scale_softmax, sink)
//   rescale    = __expf(scores_max * scale_softmax - new_max)
//   scores_sum = scores_sum * rescale + __expf(sink - new_max)
//   scores_max = new_max / scale_softmax
// and every acc_o element belonging to that row slot is multiplied by rescale.
// reduced_num_heads is accepted but not read by this body.
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_apply_attention_sink_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    const void *s_aux_ptr,
    int s_aux_type,
    int bidh,
    int reduced_num_heads,
    int ngroups,
    int m_block,
    int kBlockM,
    int lane_id,
    ElementAccum scale_softmax) {
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        #pragma unroll
        for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
            // Which sink applies to this row.
            const int row = m_block * kBlockM + mi * 32 + (lane_id & 15) + half_m * 16;
            const int head = bidh * ngroups + row % ngroups;
            const ElementAccum sink = fp8_kvcache_attention_sink_load(s_aux_ptr, s_aux_type, head);

            // Update the running max/sum with the sink treated as one extra score.
            const ElementAccum prev_max_scaled = scores_max[mi].f32[half_m] * scale_softmax;
            const ElementAccum next_max_scaled = max(prev_max_scaled, sink);
            const ElementAccum shrink = __expf(prev_max_scaled - next_max_scaled);
            scores_sum[mi].f32[half_m] = scores_sum[mi].f32[half_m] * shrink + __expf(sink - next_max_scaled);
            scores_max[mi].f32[half_m] = next_max_scaled / scale_softmax;

            // Bring the already accumulated output onto the new max.
            __float2 shrink_pair = {shrink, shrink};
            #pragma unroll
            for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
                #pragma unroll
                for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
                    auto& tile = acc_o[pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi)];
                    #pragma unroll
                    for (int half_n = 0; half_n < 2; ++half_n) {
                        #pragma unroll
                        for (int vec_id = 0; vec_id < 2; ++vec_id) {
                            tile[half_n * 2 + half_m].u64[vec_id] =
                                __builtin_hcu_pk_mul_f32(tile[half_n * 2 + half_m].u64[vec_id], shrink_pair);
                        }
                    }
                }
            }
        }
    }
}
template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Interleave2, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT> template<typename Params, int kHeadDimV, int kHeadDimVSplit, bool Interleave2, bool Split, typename SplitkvAccumType, typename ElementAccum, int kBlockM, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT>
__forceinline__ __device__ void kvcache_epilogue_store_output_gfx938( __forceinline__ __device__ void kvcache_epilogue_store_output_gfx938(
...@@ -20,8 +79,9 @@ __forceinline__ __device__ void kvcache_epilogue_store_output_gfx938( ...@@ -20,8 +79,9 @@ __forceinline__ __device__ void kvcache_epilogue_store_output_gfx938(
: reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o; : reinterpret_cast<SplitkvAccumType *>(params.o_ptr) + row_offset_o;
int pv_lane_seq_idx = lane_id & 15; int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4; int pv_lane_head_dim_idx = lane_id >> 4;
// Specialized optimizatio for headdim 128 // Specialized optimization for headdim 128. Dim256 is split into two
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1); // 128-column stores so it can use the same layout per split.
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and K_LOOP_COUNT == WARP_NUM);
if constexpr (not OPT_FOR_HDIM128) { if constexpr (not OPT_FOR_HDIM128) {
if (warp_id > 0) return; if (warp_id > 0) return;
} }
...@@ -90,8 +150,9 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938( ...@@ -90,8 +150,9 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false/*USE_CACHE_SWIZZLE*/>(o_ptr); auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false/*USE_CACHE_SWIZZLE*/>(o_ptr);
int pv_lane_seq_idx = lane_id & 15; int pv_lane_seq_idx = lane_id & 15;
int pv_lane_head_dim_idx = lane_id >> 4; int pv_lane_head_dim_idx = lane_id >> 4;
// Specialized optimizatio for headdim 128 // Specialized optimization for headdim 128. Dim256 is split into two
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1); // 128-column stores so it can use the same layout per split.
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and K_LOOP_COUNT == WARP_NUM);
if constexpr (not OPT_FOR_HDIM128) { if constexpr (not OPT_FOR_HDIM128) {
if (warp_id > 0) return; if (warp_id > 0) return;
} }
...@@ -123,4 +184,40 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938( ...@@ -123,4 +184,40 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
} }
} }
} }
} }
\ No newline at end of file
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention epilogue helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Final normalization of the output accumulators.
//
// Every acc_o element of row slot (mi, min_tile_m) is multiplied by
// v_descale / scores_sum[mi].f32[min_tile_m]. When that sum is zero or NaN the
// factor falls back to v_descale alone, so no division by the bad sum takes place.
template<int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_epilogue_rescale_acco_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum v_descale) {
    #pragma unroll
    for (int pv_n_loop = 0; pv_n_loop < K_LOOP_COUNT; ++pv_n_loop) {
        #pragma unroll
        for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < K_WARP_COUNT; ++ni) {
                const int tile_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
                #pragma unroll
                for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m) {
                    const ElementAccum row_sum = scores_sum[mi].f32[half_m];
                    // usable == sum is neither zero nor NaN (NaN compares unequal to itself)
                    const bool usable = (row_sum != 0.f) && (row_sum == row_sum);
                    const ElementAccum factor = usable ? v_descale / row_sum : v_descale;
                    __float2 factor_pair = {factor, factor};
                    #pragma unroll
                    for (int half_n = 0; half_n < 2; ++half_n) {
                        #pragma unroll
                        for (int vec_id = 0; vec_id < 2; ++vec_id) {
                            acc_o[tile_id][half_n * 2 + half_m].u64[vec_id] =
                                __builtin_hcu_pk_mul_f32(acc_o[tile_id][half_n * 2 + half_m].u64[vec_id], factor_pair);
                        }
                    }
                }
            }
        }
    }
}
...@@ -42,6 +42,13 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938( ...@@ -42,6 +42,13 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
constexpr int N_LOOP_END = 0; constexpr int N_LOOP_END = 0;
for (int n_loop = N_LOOP_START; n_loop >= N_LOOP_END; n_loop -= N_LOOP_STEP) { for (int n_loop = N_LOOP_START; n_loop >= N_LOOP_END; n_loop -= N_LOOP_STEP) {
#if defined(__gfx92a__)
ds_mpermute_kdim_for_mmac_wait(p_reg[0][0 * 2 + 0].f16x4);
ds_mpermute_kdim_for_mmac_wait(p_reg[0][0 * 2 + 1].f16x4);
ds_mpermute_kdim_for_mmac_wait(p_reg[0][1 * 2 + 0].f16x4);
ds_mpermute_kdim_for_mmac_wait(p_reg[0][1 * 2 + 1].f16x4);
#endif
#pragma unroll #pragma unroll
for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id) { for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id) {
...@@ -66,10 +73,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938( ...@@ -66,10 +73,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset; // v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + v_mls_lds_warp_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "intrinsic.h" #include "intrinsic.h"
#include "fwd/utils.h" #include "fwd/utils.h"
#include "intrinsic_mls_ds.h" #include "intrinsic_mls_ds.h"
#include "intrinsic_mls_ds_b8.h"
template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES> template<int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
...@@ -47,10 +48,376 @@ __forceinline__ __device__ void kvcache_prefetch_v_to_lds_gfx938( ...@@ -47,10 +48,376 @@ __forceinline__ __device__ void kvcache_prefetch_v_to_lds_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset; // v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset); *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
union union_vec4_uint v_rsrc_bits; inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
v_rsrc_bits.v32 = v_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(v_lds) + v_mls_lds_warp_offset;
matrix_load_b16_lds_builtin<32, 32, 1, 0>(lds_addr_warp, v_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention PV helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Forward declaration only; the definition is not in this part of the file.
// NOTE(review): judging by the name and parameters this issues the K prefetch into
// k_lds for the FP8 path -- confirm against the definition.
template<int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_k_gfx938(
vec4_uint k_addr,
Element* k_lds,
int warp_id,
int k_row_stride,
int max_seq_k_offset);
// Issues the first two V loads into LDS stage 0.
//
// Load `load_id` (0 or 1) fetches the V slice at loop index (K_LOOP_COUNT - 1 - load_id),
// i.e. a byte offset of that index * 64 * sizeof(Element), and writes it to LDS tile
// (WARP_NUM * load_id + warp_id), each tile being 32 * 64 * sizeof(Element) bytes.
// The row offset and the filter value placed in descriptor word 3 are derived from
// warp_id and max_seq_v_offset; when warp_id * 32 + 32 - max_seq_v_offset >= 32 the
// warp uses the rows of warp 0 instead of its own.
// kBlockK is part of the template signature but is not used by the body.
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_gfx938(
    vec4_uint v_addr,
    Element* v_lds,
    int warp_id,
    int v_row_stride,
    int max_seq_v_offset) {
    static_assert(K_LOOP_COUNT % 2 == 0);
    constexpr int kLoads = 2;
    constexpr int kStage = 0;
    constexpr int kLastLoop = K_LOOP_COUNT - 1;

    // Row selection is the same for both loads.
    const int filter_probe = warp_id * 32 + 32 - max_seq_v_offset;
    const int src_warp = filter_probe >= 32 ? 0 : warp_id;
    const int row_bytes = src_warp * 32 * v_row_stride * sizeof(Element);

    vec4_uint desc;
    desc[1] = v_addr[1];
    desc[2] = v_row_stride;
    #pragma unroll
    for (int load_id = 0; load_id < kLoads; ++load_id) {
        const int lds_bytes = kStage * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);
        const int slice_bytes = (kLastLoop - load_id) * 64 * sizeof(Element);
        const int filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);
        desc[3] = (filter << 8) + 0x20000;
        // the low 64 bits of the descriptor carry the address of the slice
        *(uint64_t*)&desc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + row_bytes + slice_bytes);
        inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, desc, lds_bytes, 0);
    }
}
// Issues a single V load into LDS stage 0; `load_id` (0 or 1) is chosen at compile
// time and K_LOOP_COUNT must be 2.
//
// The load fetches the V slice at loop index (K_LOOP_COUNT - 1 - load_id), i.e. a byte
// offset of that index * 64 * sizeof(Element), and writes it to LDS tile
// (WARP_NUM * load_id + warp_id), each tile being 32 * 64 * sizeof(Element) bytes.
// When warp_id * 32 + 32 - max_seq_v_offset >= 32 the warp uses the rows of warp 0.
// kBlockK is part of the template signature but is not used by the body.
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element, int load_id>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_one_gfx938(
    vec4_uint v_addr,
    Element* v_lds,
    int warp_id,
    int v_row_stride,
    int max_seq_v_offset) {
    static_assert(K_LOOP_COUNT == 2);
    static_assert(load_id == 0 || load_id == 1);
    constexpr int kStage = 0;
    constexpr int kLastLoop = K_LOOP_COUNT - 1;

    // Destination inside LDS.
    const int lds_bytes = kStage * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);

    // Source offsets in global memory.
    const int slice_bytes = (kLastLoop - load_id) * 64 * sizeof(Element);
    const int filter_probe = warp_id * 32 + 32 - max_seq_v_offset;
    const int src_warp = filter_probe >= 32 ? 0 : warp_id;
    const int row_bytes = src_warp * 32 * v_row_stride * sizeof(Element);
    const int filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);

    // Descriptor: words 1..3 first, then the 64-bit address over words 0..1.
    vec4_uint desc;
    desc[1] = v_addr[1];
    desc[2] = v_row_stride;
    desc[3] = (filter << 8) + 0x20000;
    *(uint64_t*)&desc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + row_bytes + slice_bytes);

    inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, desc, lds_bytes, 0);
}
// P*V GEMM stage for the fp8 KV-cache path (16-bit MMAC variant), with optional K prefetch.
//
// V is streamed from global memory into LDS as fp8, read back, widened fp8 -> f32 -> packed
// 16-bit pairs (via __builtin_hcu_cvt_pk_f32_fp8 / __builtin_hcu_cvt_pk_f16_f32), and multiplied
// with the probability fragments in `p_reg` using mmac_4interleave. Results are accumulated in
// place into `pv_reg`.
//
// V is consumed PREFETCH (= 2) 64-column blocks at a time, from the highest block index down to
// block 0. Two LDS stages, 16384 bytes apart, are ping-ponged: the main loop fills one stage
// while computing from the other; the trailing scope computes the final two blocks.
//
// NOTE(review): the first stage that is read (stage 0) is never filled inside this function, so
// the highest PREFETCH column blocks are presumably already in flight from an earlier V prefetch
// (e.g. fp8_kvcache_prefetch_v_one_gfx938) -- confirm against the caller.
// NOTE(review): pv_reg is indexed up to [2 * K_LOOP_COUNT - 1], which only fits its declared
// extent when M_WARP_COUNT * K_WARP_COUNT >= 2; the fp8 sibling asserts 1 x 2, this one does not.
//
// Template parameters:
//   PrefetchK      when true, k_addr is advanced by k_addr_offset and a K prefetch is issued
//   K_LOOP_COUNT   number of 64-column V blocks; must be even
//   kBlockN        subtracted from max_seq_v_offset for the K prefetch row limit
//   WARP_NUM       warps per block; used to lay out per-warp LDS slots
//   M_MMAC_COUNT   number of MMAC tiles along M handled per warp
//   (kBlockK is not referenced in the body.)
//
// Parameters:
//   v_addr            V base descriptor; low 64 bits are the base address, word 1 is copied
//   k_addr            K base descriptor, updated in place when PrefetchK is set
//   v_lds, k_lds      LDS base pointers forwarded to the load helpers
//   p_reg             probability fragments; only p_reg[0][...] is read
//   pv_reg            output accumulators, updated in place
//   warp_id           index of the calling warp inside the block
//   k_row_stride      forwarded to the K prefetch
//   v_row_stride      stored in V descriptor word 2 and used for the row byte offset
//   max_seq_v_offset  row limit used to derive the per-warp row filter
//   k_addr_offset     value added to the 64-bit address held at the start of k_addr
template<bool PrefetchK, int K_LOOP_COUNT, int kBlockK, int kBlockN, int M_WARP_COUNT, int K_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT, typename V_Element, typename P_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint& k_addr,
V_Element* v_lds,
V_Element* k_lds,
union_vec2_f16x2<P_Element> p_reg[M_WARP_COUNT * K_WARP_COUNT][4],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_v_offset,
int64_t k_addr_offset) {
static_assert(K_LOOP_COUNT % 2 == 0);
constexpr int PREFETCH = 2;
// Presumably drains outstanding LDS traffic and synchronizes before LDS is reused -- TODO
// confirm the exact semantics of wait_lds_data_arrived<true>.
flash::wait_lds_data_arrived<true/*sync*/>(0);
// Descriptor for the V loads; word 1 is set here but rewritten by the 64-bit address store below.
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
// Stage that the next loads will fill; stage 0 is expected to hold the first data to consume.
int stage_id = 1;
// Main pipeline: each iteration loads column blocks k_loop and k_loop - 1 into `stage_id`
// and computes on the two blocks fetched previously (k_loop + 2 and k_loop + 1).
#pragma unroll
for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// LDS destination: stage base + slot (WARP_NUM * load_id + warp_id), 32 * 64 elements per slot.
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
int warp_global_bytes;
// Column offset of the 64-wide block being fetched.
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
// A warp whose 32 rows all lie at or past max_seq_v_offset reads warp 0's rows instead.
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
// Presumably the number of trailing rows to suppress, clamped to [0, 32] -- TODO confirm.
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
// NOTE(review): meaning of the 0x20000 flag in descriptor word 3 is not visible here.
v_srsrc[3] = (nm_filter << 8) + 0x20000;
// Overwrites descriptor words 0..1 with the final 64-bit source address.
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
// Presumably waits until only the PREFETCH loads just issued remain outstanding, i.e. the
// previously requested stage has landed -- TODO confirm.
flash::wait_buffer_data_arrived<false/*sync*/>(PREFETCH);
// Switch to the stage that was filled earlier and consume it.
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
// Two LDS matrix reads per slot, 32 bytes apart.
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
// Column block that the data read from this stage belongs to.
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
// Wait count 1 for the first tile and 0 for the second, so compute on tile 0 can start
// while the second LDS read is presumably still in flight.
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
// Keep the compiler from moving instructions across the priority-raised MMAC section.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// Widen two packed fp8 dwords to four f32 pairs (low / high word of each dword).
vec2_fp32 v_f32x2[4];
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
// Narrow each f32 pair to a packed 16-bit pair for the MMAC operand.
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
// Accumulate: pv += P fragment (mmac_id, min_tile_m) x V fragment (mmac_id).
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Restore the default wave priority.
asm volatile("s_setprio 0");
}
}
}
// Presumably waits for every outstanding buffer load before the final stage is read.
flash::wait_buffer_data_arrived<false/*sync*/>(0);
// When K_LOOP_COUNT == 2 the K prefetch is issued in the middle of the tail compute;
// otherwise it is issued after the tail (see the end of the function).
constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;
{
// Tail: consume the last stage; with k_loop = -1 the blocks processed are 1 and 0.
constexpr int k_loop = 1 - PREFETCH;
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
if constexpr (PrefetchKInPV) {
// Issue the next K prefetch exactly once, after both LDS reads of the first slot have
// been waited on (wait count 0 above).
if (load_id == 0 && tile32x32_id == 1) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
// Same fp8 -> f32 -> packed 16-bit widening as in the main loop.
vec2_fp32 v_f32x2[4];
v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false/*word_sel*/);
v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true/*word_sel*/);
v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false/*word_sel*/);
v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true/*word_sel*/);
union_vec4_f16x2<P_Element> v_f16x8;
v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false/*clamp*/, 0/*o_modifier*/);
v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false/*clamp*/, 0/*o_modifier*/);
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int mmac_id = 0; mmac_id < 2; ++mmac_id) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave<P_Element, ElementAccum>(
p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
v_f16x8.f16x4[mmac_id],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
// K prefetch for the K_LOOP_COUNT > 2 case: issued once all V compute has been emitted.
if constexpr (PrefetchK && !PrefetchKInPV) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
// Same entry-style wait as at the top of the function, before the caller continues.
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
// P*V GEMM stage for the fp8 KV-cache path where P is already packed as 8-bit data, with
// optional K prefetch.
//
// V tiles are streamed from global memory into LDS, read back and fed directly (no widening) to
// mmac_4interleave_b8 together with `p_reg`; results are accumulated in place into `pv_reg`.
//
// V is consumed PREFETCH (= 2) 64-column blocks at a time, from the highest block index down to
// block 0, ping-ponging between two LDS stages that are 16384 bytes apart.
//
// NOTE(review): the first stage that is read (stage 0) is never filled inside this function, so
// the highest PREFETCH column blocks are presumably already in flight from an earlier V prefetch
// (e.g. fp8_kvcache_prefetch_v_one_gfx938) -- confirm against the caller.
//
// Template parameters:
//   PrefetchK      when true, k_addr is advanced by k_addr_offset and a K prefetch is issued
//   K_LOOP_COUNT   number of 64-column V blocks; must be even
//   kBlockN        subtracted from max_seq_v_offset for the K prefetch row limit
//   M_WARP_COUNT   must be 1
//   K_WARP_COUNT   must be 2
//   WARP_NUM       warps per block; used to lay out per-warp LDS slots
//   M_MMAC_COUNT   number of MMAC tiles along M handled per warp
//   (kBlockK is not referenced in the body.)
//
// Parameters:
//   v_addr            V base descriptor; low 64 bits are the base address, word 1 is copied
//   k_addr            K base descriptor, updated in place when PrefetchK is set
//   v_lds, k_lds      LDS base pointers forwarded to the load helpers
//   p_reg             packed 8-bit probability fragments; only i8x8[0] of each entry is read
//   pv_reg            output accumulators, updated in place
//   warp_id           index of the calling warp inside the block
//   k_row_stride      forwarded to the K prefetch
//   v_row_stride      stored in V descriptor word 2 and used for the row byte offset
//   max_seq_v_offset  row limit used to derive the per-warp row filter
//   k_addr_offset     value added to the 64-bit address held at the start of k_addr
template<bool PrefetchK, int K_LOOP_COUNT, int kBlockK, int kBlockN, int M_WARP_COUNT, int K_WARP_COUNT, int WARP_NUM, int M_MMAC_COUNT, typename V_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938(
vec4_uint v_addr,
vec4_uint& k_addr,
V_Element* v_lds,
V_Element* k_lds,
union_vec32_fp8 p_reg[M_MMAC_COUNT],
vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_v_offset,
int64_t k_addr_offset) {
static_assert(K_LOOP_COUNT % 2 == 0);
// pv_reg is indexed as [block * 2 + tile], which matches an extent of K_LOOP_COUNT * 1 * 2.
static_assert(M_WARP_COUNT == 1);
static_assert(K_WARP_COUNT == 2);
constexpr int PREFETCH = 2;
// Presumably drains outstanding LDS traffic and synchronizes before LDS is reused -- TODO
// confirm the exact semantics of wait_lds_data_arrived<true>.
flash::wait_lds_data_arrived<true/*sync*/>(0);
// Descriptor for the V loads; word 1 is set here but rewritten by the 64-bit address store below.
vec4_uint v_srsrc;
v_srsrc[1] = v_addr[1];
v_srsrc[2] = v_row_stride;
// Stage that the next loads will fill; stage 0 is expected to hold the first data to consume.
int stage_id = 1;
// Main pipeline: each iteration loads column blocks k_loop and k_loop - 1 into `stage_id`
// and computes on the two blocks fetched previously (k_loop + 2 and k_loop + 1).
#pragma unroll
for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
// LDS destination: stage base + slot (WARP_NUM * load_id + warp_id), 32 * 64 elements per slot.
int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
int warp_global_bytes;
// Column offset of the 64-wide block being fetched.
int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
// A warp whose 32 rows all lie at or past max_seq_v_offset reads warp 0's rows instead.
int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
// Presumably the number of trailing rows to suppress, clamped to [0, 32] -- TODO confirm.
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
// NOTE(review): meaning of the 0x20000 flag in descriptor word 3 is not visible here.
v_srsrc[3] = (nm_filter << 8) + 0x20000;
// Overwrites descriptor words 0..1 with the final 64-bit source address.
*(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + warp_global_bytes + v_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
}
// Presumably waits until only the PREFETCH loads just issued remain outstanding, i.e. the
// previously requested stage has landed -- TODO confirm.
flash::wait_buffer_data_arrived<false/*sync*/>(PREFETCH);
// Switch to the stage that was filled earlier and consume it.
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
// Two LDS matrix reads per slot, 32 bytes apart.
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
// Column block that the data read from this stage belongs to.
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
// Wait count 1 for the first tile and 0 for the second, so compute on tile 0 can start
// while the second LDS read is presumably still in flight.
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
// Keep the compiler from moving instructions across the priority-raised MMAC section.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
// Accumulate: pv += P fragment (min_tile_m) x V fragment (min_tile_dim), both 8-bit packed.
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave_b8<int8_t, ElementAccum>(
p_reg[min_tile_m].i8x8[0],
v_regs[tile32x32_id].i8x8[min_tile_dim],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
// Restore the default wave priority.
asm volatile("s_setprio 0");
}
}
}
// Presumably waits for every outstanding buffer load before the final stage is read.
flash::wait_buffer_data_arrived<false/*sync*/>(0);
// When K_LOOP_COUNT == 2 the K prefetch is issued in the middle of the tail compute;
// otherwise it is issued after the tail (see the end of the function).
constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;
{
// Tail: consume the last stage; with k_loop = -1 the blocks processed are 1 and 0.
constexpr int k_loop = 1 - PREFETCH;
stage_id ^= 1;
#pragma unroll
for (int load_id = 0; load_id < PREFETCH; ++load_id) {
union_vec16_fp8 v_regs[2];
int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false/*transpose*/)
DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false/*transpose*/)
int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - tile32x32_id);
if constexpr (PrefetchKInPV) {
// Issue the next K prefetch exactly once, after both LDS reads of the first slot have
// been waited on (wait count 0 above).
if (load_id == 0 && tile32x32_id == 1) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 2");
#pragma unroll
for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
mmac_4interleave_b8<int8_t, ElementAccum>(
p_reg[min_tile_m].i8x8[0],
v_regs[tile32x32_id].i8x8[min_tile_dim],
pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
// K prefetch for the K_LOOP_COUNT > 2 case: issued once all V compute has been emitted.
if constexpr (PrefetchK && !PrefetchKInPV) {
*(int64_t*)&k_addr += k_addr_offset;
fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
}
// Same entry-style wait as at the top of the function, before the caller continues.
flash::wait_lds_data_arrived<true/*sync*/>(0);
}
// Packs the f32 score accumulators of one tile into 8-bit values.
//
// For every M sub-tile `m`, the two accumulators s_reg[0][m] and s_reg[0][2 + m] are converted
// with __builtin_hcu_cvt_pk4_fp8_f32 and stored into dwords 0 and 1 of p_reg[m] respectively.
//
// p_reg  destination, one entry per M sub-tile; only i32[0] and i32[1] are written
// s_reg  source accumulators; only s_reg[0][...] is read
template <int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void fp8_kvcache_cvt_f32_to_fp8_gfx938(
    union_vec32_fp8 p_reg[M_MMAC_COUNT],
    vec4_Accum<ElementAccum> s_reg[1][4]) {
    #pragma unroll
    for (int m = 0; m < M_MMAC_COUNT; ++m) {
        #pragma unroll
        for (int n = 0; n < 2; ++n) {
            // Accumulator slot n * 2 + m feeds output dword n of this M sub-tile.
            __builtin_hcu_cvt_pk4_fp8_f32<Element>(s_reg[0][n * 2 + m].f32, p_reg[m].i32[n]);
        }
    }
}
...@@ -73,10 +73,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938( ...@@ -73,10 +73,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset; // k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset); *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/; int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
} }
// 等待 MLS 数据回来 // 等待 MLS 数据回来
...@@ -272,3 +269,4 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938( ...@@ -272,3 +269,4 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
} }
} // qk_gemm } // qk_gemm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
#include "static_switch.h" #include "static_switch.h"
#include "kvcache_pv_gemm_utils_gfx938.h" #include "kvcache_pv_gemm_utils_gfx938.h"
#include "intrinsic_mls_ds_b8.h"
template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT> template<int kHeadDim, int kBlockM, int kBlockK, int WARP_M, int WARP_NUM, typename Element, int STAGES, int M_MMAC_COUNT>
...@@ -40,10 +41,7 @@ __forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_gfx938( ...@@ -40,10 +41,7 @@ __forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_gfx938(
} }
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/; int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
flash::wait_lds_data_arrived<true>(0); flash::wait_lds_data_arrived<true>(0);
union union_vec4_uint q_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(q_lds, q_srsrc, lds_offset_bytes, 0);
q_rsrc_bits.v32 = q_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(q_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, q_rsrc_bits.i32, 0);
flash::wait_buffer_data_arrived<true>(0); flash::wait_buffer_data_arrived<true>(0);
...@@ -199,11 +197,175 @@ __forceinline__ __device__ void kvcache_prefetch_k_to_lds_gfx938( ...@@ -199,11 +197,175 @@ __forceinline__ __device__ void kvcache_prefetch_k_to_lds_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset; // k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset); *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/; int lds_offset_bytes = k_lds_stage_offset * 2/*half -> bytes*/;
union union_vec4_uint k_rsrc_bits; inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
k_rsrc_bits.v32 = k_srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(k_lds) + lds_offset_bytes;
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, k_rsrc_bits.i32, 0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Issues one buffer->LDS load of this warp's 32-row slice of the first 64-column fp8 K block.
//
// The data is written to LDS stage 0, slot `warp_id` (32 * 64 elements per slot).
//
// k_addr            base descriptor of K; its low 64 bits are used as the base address
// k_lds             LDS base pointer handed to the load helper
// warp_id           index of the calling warp inside the block
// k_row_stride      row stride, stored verbatim in descriptor word 2
// max_seq_k_offset  row limit used to derive the row filter (see below)
template<int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_k_gfx938(
    vec4_uint k_addr,
    Element* k_lds,
    int warp_id,
    int k_row_stride,
    int max_seq_k_offset) {
    constexpr int kColumnBlock = 0;
    int lds_stage = 0;

    // A warp whose 32 rows all sit at or past the limit re-reads warp 0's rows instead.
    int rows_past_limit = warp_id * 32 + 32 - max_seq_k_offset;
    int source_warp = (rows_past_limit >= 32) ? 0 : warp_id;
    // Presumably the count of trailing rows to suppress, clamped to [0, 32] -- TODO confirm
    // against inline_min_max and the descriptor layout.
    int row_filter = inline_min_max<0, 32>(source_warp * 32 + 32 - max_seq_k_offset);

    // Byte offsets: destination inside LDS, then source row / column inside global memory.
    int lds_dst_bytes = (lds_stage * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
    int global_row_bytes = source_warp * 32 * k_row_stride * sizeof(Element);
    int global_col_bytes = kColumnBlock * 64 * sizeof(Element);

    // Build the load descriptor. Words 0..1 are rewritten last by the 64-bit address store.
    vec4_uint descriptor;
    descriptor[1] = k_addr[1];
    descriptor[2] = k_row_stride;
    descriptor[3] = (row_filter << 8) + 0x40000;
    *(uint64_t*)&descriptor = VA_LIMIT_BITS(*(uint64_t*)&k_addr + global_row_bytes + global_col_bytes);

    inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, descriptor, lds_dst_bytes, 0);
    // Stop the compiler from scheduling later instructions ahead of the load issue.
    __builtin_amdgcn_sched_barrier(0);
}
// Forward declaration only: the definition is not in this header. It is declared here because
// fp8_kvcache_qk_gemm_gfx938 below calls it (with load_id 0 and 1) when PrefetchVInQK is set.
template<int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element, int load_id>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_one_gfx938(
vec4_uint v_addr,
Element* v_lds,
int warp_id,
int v_row_stride,
int max_seq_v_offset);
// Q*K^T GEMM stage for the fp8 KV-cache path, optionally overlapping the V prefetch.
//
// The score accumulators in `s_reg` are zeroed, then the head dimension is walked in 64-element
// blocks. K blocks are streamed global -> LDS into two alternating stages; each block is read
// back from LDS (transposed) and multiplied with the matching Q fragment via mmac_4interleave_b8.
//
// K block 0 is never loaded here: the first LDS read comes from stage 0, which is where
// fp8_kvcache_prefetch_k_gfx938 places block 0, so that prefetch is expected to have been issued
// by the caller beforehand.
//
// When PrefetchVInQK is set (only legal for kHeadDim == 128 and K_LOOP_COUNT == 2), the two V
// prefetches (load_id 0 and 1) are issued in the middle of the two compute phases.
//
// Template parameters:
//   PrefetchVInQK  issue V prefetches from inside this function
//   K_LOOP_COUNT   forwarded to the V prefetch helper
//   kHeadDim       head dimension; kHeadDim / 64 K blocks are processed
//   kBlockK        forwarded to the V prefetch helper
//   WARP_M, WARP_N sizes that determine the extent of s_reg and k_regs
//   WARP_NUM       warps per block; used to lay out per-warp LDS slots
//   M_MMAC_COUNT   number of MMAC tiles along M handled per warp
//
// Parameters:
//   k_addr            K base descriptor; low 64 bits are the base address, word 1 is copied
//   v_addr, v_lds     forwarded to the V prefetch helper
//   k_lds             LDS base pointer forwarded to the K load helper
//   q_reg             Q fragments, indexed [M tile][64-element block]
//   s_reg             score accumulators, zeroed and then accumulated in place (only s_reg[0])
//   warp_id           index of the calling warp inside the block
//   k_row_stride      stored in K descriptor word 2 and used for the row byte offset
//   v_row_stride      forwarded to the V prefetch helper
//   max_seq_k_offset  row limit used to derive the per-warp row filter; also passed to the
//                     V prefetch helper as its row limit
template<bool PrefetchVInQK, int K_LOOP_COUNT, int kHeadDim, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_qk_gemm_gfx938(
vec4_uint k_addr,
vec4_uint v_addr,
Element* k_lds,
Element* v_lds,
union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
int warp_id,
int k_row_stride,
int v_row_stride,
int max_seq_k_offset = 0) {
static_assert(!PrefetchVInQK || (kHeadDim == 128 && K_LOOP_COUNT == 2));
int stage_id = 0;
// Descriptor for the K loads; word 1 is set here but rewritten by the 64-bit address store below.
vec4_uint k_srsrc;
k_srsrc[1] = k_addr[1];
k_srsrc[2] = k_row_stride;
// Zero the accumulators with two 64-bit moves per entry.
// NOTE(review): the bound uses (WARP_N / WARP_N), which is always 1, while s_reg is declared
// with (WARP_N / 32) entries along N. The two agree only for WARP_N == 32 -- confirm whether
// WARP_N / 32 was intended.
#pragma unroll
for (int i = 0; i < (WARP_N / WARP_N) * (WARP_M / 32); ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
asm volatile(
"v_mov_b64 %0, 0x0\n\t"
"v_mov_b64 %1, 0x0\n\t"
: "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[0]),
"=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[1])
:);
}
}
}
// Next load goes to stage 1; stage 0 is expected to already hold K block 0.
stage_id ^= 1;
// Pipeline: iteration k_loop loads K block k_loop and computes on block k_loop - 1.
#pragma unroll
for (int k_loop = 1; k_loop < kHeadDim / 64; ++k_loop) {
// LDS destination slot for this warp in the stage being filled.
int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
int warp_global_bytes;
// Column offset of the 64-wide block being fetched.
int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
// A warp whose 32 rows all lie at or past max_seq_k_offset reads warp 0's rows instead.
int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset;
int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
// Presumably the number of trailing rows to suppress, clamped to [0, 32] -- TODO confirm.
int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset);
// NOTE(review): meaning of the 0x40000 flag in descriptor word 3 is not visible here.
k_srsrc[3] = (nm_filter << 8) + 0x40000;
// Overwrites descriptor words 0..1 with the final 64-bit source address.
*(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + warp_global_bytes + k_loop_global_bytes);
inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
// Presumably waits until only the load just issued remains outstanding, i.e. the previous
// K block has landed in LDS -- TODO confirm.
flash::wait_buffer_data_arrived<false/*sync*/>(1);
// Switch to the stage holding the previous block and read it back.
stage_id ^= 1;
union_vec16_fp8 k_regs[WARP_N / 16];
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
// Two transposed LDS matrix reads, 1024 bytes apart.
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// Wait count 1 for the first N tile and 0 for the second, so compute on tile 0 can start
// while the second LDS read is presumably still in flight.
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
if constexpr (PrefetchVInQK) {
// First V prefetch (load_id 0), issued once both LDS reads have been waited on.
if (min_tile_n == 1) {
fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 0>(
v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
}
}
// Keep the compiler from moving instructions across the priority-raised MMAC section.
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
// Accumulate: s += Q fragment (block k_loop - 1) x K fragment (N tile min_tile_n).
s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
mmac_4interleave_b8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
// Restore the default wave priority.
asm volatile("s_setprio 0");
}
}
{
// Tail: compute on the last K block (index kHeadDim / 64 - 1); no further K load is issued.
constexpr int k_loop = kHeadDim / 64;
// With PrefetchVInQK one load is allowed to stay outstanding (presumably the V prefetch
// issued above); otherwise all buffer loads are waited for.
if constexpr (PrefetchVInQK) {
flash::wait_buffer_data_arrived<false/*sync*/>(1);
} else {
flash::wait_buffer_data_arrived<false/*sync*/>(0);
}
stage_id ^= 1;
union_vec16_fp8 k_regs[WARP_N / 16];
int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true/*transpose*/)
DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true/*transpose*/)
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
flash::wait_lds_data_arrived<false/*sync*/>(1 - min_tile_n);
if constexpr (PrefetchVInQK) {
// Second V prefetch (load_id 1), issued once both LDS reads have been waited on.
if (min_tile_n == 1) {
fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 1>(
v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 1");
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
mmac_4interleave_b8<int8_t, ElementAccum>(
q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
k_regs[min_tile_n].i8x8[min_tile_k],
s_reg[0][min_tile_n * 2 + min_tile_m].f32);
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_setprio 0");
}
}
}
...@@ -65,6 +65,44 @@ inline __device__ void kvcache_apply_mask_causal_gfx938(DataType tensor[M_WARP_C ...@@ -65,6 +65,44 @@ inline __device__ void kvcache_apply_mask_causal_gfx938(DataType tensor[M_WARP_C
} }
// Applies a sliding-window (local) causal mask to the score accumulators in `tensor`.
//
// Each lane derives the row / column of every accumulator element it owns from its lane id
// (row = lane & 15, column group = (lane >> 4) * 4) plus the tile offsets, and overwrites the
// element with -INFINITY when its column lies outside the visible window of its row.
//
// Rows are packed with `ngroups` query heads per logical query position, so the logical row is
// row_idx / ngroups and the logical query length is max_seqlen_q / ngroups. For a logical row r
// the visible columns are
//     [ max(0, r + max_seqlen_k - logical_q - window_size_left),
//       min(max_seqlen_k - 1, r + max_seqlen_k - logical_q + window_size_right) ]
// (both ends inclusive).
//
// The upper end is clamped to max_seqlen_k - 1, the last valid key column. Clamping it to
// max_seqlen_k, as an earlier version did, left column `max_seqlen_k` (one past the end of the
// sequence) unmasked whenever the window reached the end of the keys.
//
// tensor             accumulators, indexed [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m]
// col_idx_offset_    column of the tile's first element
// max_seqlen_k       number of valid key columns
// row_idx_offset_    row of the tile's first element
// max_seqlen_q       number of query rows (including the ngroups packing)
// ngroups            query heads packed per logical query position; must be non-zero
// window_size_left   columns visible to the left of the diagonal
// window_size_right  columns visible to the right of the diagonal
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_local_causal_gfx938(
    DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
    const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q,
    const int ngroups, const int window_size_left, const int window_size_right) {
    const int lane_id = threadIdx.x & 63;
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
    // Loop-invariant: logical query length after unpacking the head groups.
    const int logical_q = max_seqlen_q / ngroups;
    #pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int logical_row = row_idx / ngroups;
            // First and last visible column for this row (inclusive on both ends).
            const int col_idx_limit_left = max(0, logical_row + max_seqlen_k - logical_q - window_size_left);
            const int col_idx_limit_right = min(max_seqlen_k - 1, logical_row + max_seqlen_k - logical_q + window_size_right);
            #pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    #pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx;
                        tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx] =
                            (col_idx < col_idx_limit_left || col_idx > col_idx_limit_right)
                                ? -INFINITY
                                : tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32[vec_idx];
                    }
                }
            }
        }
    }
}
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT> template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_, inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4], const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_k, const int row_idx_offset_,
...@@ -96,3 +134,26 @@ inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WA ...@@ -96,3 +134,26 @@ inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WA
} }
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention score helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiplies every score accumulator by the packed descale factor `qk_descale`.
//
// Each accumulator entry is treated as two 64-bit halves (u64[0], u64[1]); both halves are
// scaled with __builtin_hcu_pk_mul_f32. Only the entries at index min_tile_n * 2 + min_tile_m
// with min_tile_m < M_MMAC_COUNT are touched.
//
// tensor      accumulators, scaled in place
// qk_descale  packed pair of scale factors applied to each 64-bit half
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_kvcache_apply_descale_gfx938(
    DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
    const __float2 qk_descale) {
    constexpr int kTileCount = M_WARP_COUNT * N_WARP_COUNT;
    #pragma unroll
    for (int tile = 0; tile < kTileCount; ++tile) {
        #pragma unroll
        for (int m = 0; m < M_MMAC_COUNT; ++m) {
            #pragma unroll
            for (int n = 0; n < 2; ++n) {
                DataType& acc = tensor[tile][n * 2 + m];
                #pragma unroll
                for (int half = 0; half < 2; ++half) {
                    acc.u64[half] = __builtin_hcu_pk_mul_f32(acc.u64[half], qk_descale);
                }
            }
        }
    }
}
...@@ -32,7 +32,7 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -32,7 +32,7 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
// #################################################################################################################################################### // ####################################################################################################################################################
// 4 个 wave 共同参与 acc_o 在 LDS 中的相加 // 4 个 wave 共同参与 acc_o 在 LDS 中的相加
// 判断当前架构是否支持 pk_f32 指令 // 判断当前架构是否支持 pk_f32 指令
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
constexpr bool SUPPORT_PK_F32 = true; constexpr bool SUPPORT_PK_F32 = true;
#else #else
constexpr bool SUPPORT_PK_F32 = false; constexpr bool SUPPORT_PK_F32 = false;
...@@ -73,13 +73,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -73,13 +73,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -92,13 +92,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -92,13 +92,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -111,13 +111,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -111,13 +111,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -130,13 +130,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -130,13 +130,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -149,13 +149,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -149,13 +149,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -168,13 +168,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -168,13 +168,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
} }
// asm volatile("s_nop 8\n"); // asm volatile("s_nop 8\n");
{ {
...@@ -187,13 +187,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -187,13 +187,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(6)\n"); asm volatile("s_waitcnt lgkmcnt(6)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[0].u64);
asm volatile("s_waitcnt lgkmcnt(5)\n"); asm volatile("s_waitcnt lgkmcnt(5)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[0].u64);
asm volatile("s_waitcnt lgkmcnt(4)\n"); asm volatile("s_waitcnt lgkmcnt(4)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[0].u64);
} }
// 先写一部分数据到 lds // 先写一部分数据到 lds
for (int loop_id = 0; loop_id < 7; ++loop_id) { for (int loop_id = 0; loop_id < 7; ++loop_id) {
...@@ -206,13 +206,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -206,13 +206,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm volatile("s_waitcnt lgkmcnt(2)\n"); asm volatile("s_waitcnt lgkmcnt(2)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave1[1].u64);
asm volatile("s_waitcnt lgkmcnt(1)\n"); asm volatile("s_waitcnt lgkmcnt(1)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave2[1].u64);
asm volatile("s_waitcnt lgkmcnt(0)\n"); asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_tmp_wave0[loop_id].u64 = hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64); acc_tmp_wave0[loop_id].u64 = __builtin_hcu_pk_add_f32(acc_tmp_wave0[loop_id].u64, acc_tmp_wave3[1].u64);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
acc_o_lds[lds_offset[loop_id]] = acc_tmp_wave0[loop_id].f32[0]; acc_o_lds[lds_offset[loop_id]] = acc_tmp_wave0[loop_id].f32[0];
acc_o_lds[lds_offset[loop_id] + 16] = acc_tmp_wave0[loop_id].f32[1]; acc_o_lds[lds_offset[loop_id] + 16] = acc_tmp_wave0[loop_id].f32[1];
...@@ -233,26 +233,23 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -233,26 +233,23 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
union_vec2_fp32 acc_tmp; union_vec2_fp32 acc_tmp;
int lds_offset0 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 0*16 + (lane_id>>4)*4 + WARP_ID; int lds_offset0 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 0*16 + (lane_id>>4)*4 + WARP_ID;
int lds_offset1 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 1*16 + (lane_id>>4)*4 + WARP_ID; int lds_offset1 = min_tile_m*__kHeadDim + q_seq_idx*2*__kHeadDim + h_idx*kBlockK + k_idx*32 + 1*16 + (lane_id>>4)*4 + WARP_ID;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0, acc_tmp.u64, 0, 16); acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
// acc_tmp.f32[0] = acc_o_lds[lds_offset0]; // acc_tmp.f32[0] = acc_o_lds[lds_offset0];
// acc_tmp.f32[1] = acc_o_lds[lds_offset1]; // acc_tmp.f32[1] = acc_o_lds[lds_offset1];
union_vec2_fp32 acc_tmp_wave1; union_vec2_fp32 acc_tmp_wave1;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave1.u64, 0, 16); acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
// acc_tmp_wave1.f32[0] = acc_o_lds[lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave1.f32[0] = acc_o_lds[lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave1.f32[1] = acc_o_lds[lds_offset1 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave1.f32[1] = acc_o_lds[lds_offset1 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave1.f32[0]; acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1]; acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2; union_vec2_fp32 acc_tmp_wave2;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave2.u64, 0, 16); acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
// acc_tmp_wave2.f32[0] = acc_o_lds[lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave2.f32[0] = acc_o_lds[lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave2.f32[1] = acc_o_lds[lds_offset1 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave2.f32[1] = acc_o_lds[lds_offset1 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave2.f32[0]; acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1]; acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3; union_vec2_fp32 acc_tmp_wave3;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, acc_tmp_wave3.u64, 0, 16); acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
// acc_tmp_wave3.f32[0] = acc_o_lds[lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave3.f32[0] = acc_o_lds[lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave3.f32[1] = acc_o_lds[lds_offset1 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim]; // acc_tmp_wave3.f32[1] = acc_o_lds[lds_offset1 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp.f32[0] += acc_tmp_wave3.f32[0]; acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
...@@ -295,4 +292,4 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce( ...@@ -295,4 +292,4 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
} }
} }
} }
} }
\ No newline at end of file
...@@ -69,7 +69,7 @@ __forceinline__ __device__ void int8_kvcache_qk_gemm_prefetch_v_3stage( ...@@ -69,7 +69,7 @@ __forceinline__ __device__ void int8_kvcache_qk_gemm_prefetch_v_3stage(
auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_q, 2>; auto BUFFER_LOAD_FUNC = &inline_buffer_load_dword_lds<Element_q, 2>;
// load 指令发下去之后, 先做一些初始化运算 // load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
if constexpr (M_MMAC_COUNT == 1) { if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg); inline_vgpr4_init_zero_1x2x4(s_reg);
} else { } else {
......
...@@ -228,7 +228,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA ...@@ -228,7 +228,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
#pragma unroll #pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) { for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32 // 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary[m_idx*2].u64 = 0x0; // 可以更狠一点, 直接初始化成第一个 additem_pair, 但是貌似容易导致编译器出问题, 影响不大, 可以不加 summary[m_idx*2].u64 = 0x0; // 可以更狠一点, 直接初始化成第一个 additem_pair, 但是貌似容易导致编译器出问题, 影响不大, 可以不加
#pragma unroll #pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) { for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
...@@ -236,7 +236,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA ...@@ -236,7 +236,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { for(int vec_idx=0; vec_idx<4; vec_idx++) {
__float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]};
summary[m_idx*2].u64 = hcu_pk_add_f32( summary[m_idx*2].u64 = __builtin_hcu_pk_add_f32(
summary[m_idx*2].u64, summary[m_idx*2].u64,
additem_pair additem_pair
); );
...@@ -262,7 +262,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA ...@@ -262,7 +262,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
} else { } else {
#pragma unroll #pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) { for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
summary_cur[m_idx*2].u64 = summary[m_idx*2].u64; summary_cur[m_idx*2].u64 = summary[m_idx*2].u64;
#pragma unroll #pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) { for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
...@@ -270,7 +270,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA ...@@ -270,7 +270,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { // mmac min_tile is 16*16, a warp is 64 thread for(int vec_idx=0; vec_idx<4; vec_idx++) { // mmac min_tile is 16*16, a warp is 64 thread
__float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]}; __float2 additem_pair = {tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2].f32[vec_idx], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + 1].f32[vec_idx]};
summary_cur[m_idx*2].u64 = hcu_pk_add_f32( summary_cur[m_idx*2].u64 = __builtin_hcu_pk_add_f32(
summary_cur[m_idx*2].u64, summary_cur[m_idx*2].u64,
additem_pair additem_pair
); );
...@@ -370,15 +370,14 @@ inline __device__ void int8_kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M/32 ...@@ -370,15 +370,14 @@ inline __device__ void int8_kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M/32
// instruction instead of fadd and fmul separately. // instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16 // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
for(int vec_idx = 0; vec_idx < 2; vec_idx++) { for(int vec_idx = 0; vec_idx < 2; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx] = hcu_pk_fma_f32( tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx], tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].u64[vec_idx],
scale_pair, scale_pair,
neg_max_scaled_pair neg_max_scaled_pair
); );
} }
asm volatile("s_nop 0" ::: "memory");
for(int vec_idx = 0; vec_idx < 4; vec_idx++) { for(int vec_idx = 0; vec_idx < 4; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]); tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] = __llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
} }
...@@ -479,10 +478,10 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3 ...@@ -479,10 +478,10 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16 // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// 936 及之后的架构有 pk_mul 指令 // 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
#pragma unroll #pragma unroll
for(int vec_idx = 0; vec_idx < 2; vec_idx++) { for(int vec_idx = 0; vec_idx < 2; vec_idx++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx] = hcu_pk_mul_f32( acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx], acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m].u64[vec_idx],
scores_scale_pair scores_scale_pair
); );
...@@ -534,8 +533,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3 ...@@ -534,8 +533,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
#pragma unroll #pragma unroll
for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++) { for(int warp_loop=1; warp_loop<WARP_NUM; warp_loop++) {
__float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop*WARP_M + mi*32 + lane_id*2); __float2 other_warp_sum = *(__float2*)(sum_lds + warp_loop*WARP_M + mi*32 + lane_id*2);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
cur_wave_sum = hcu_pk_add_f32(cur_wave_sum, other_warp_sum); cur_wave_sum = __builtin_hcu_pk_add_f32(cur_wave_sum, other_warp_sum);
#else #else
cur_wave_sum[0] += other_warp_sum[0]; cur_wave_sum[0] += other_warp_sum[0];
cur_wave_sum[1] += other_warp_sum[1]; cur_wave_sum[1] += other_warp_sum[1];
...@@ -559,8 +558,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3 ...@@ -559,8 +558,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
} }
for (int mi = 0; mi < (WARP_M/32); ++mi) { for (int mi = 0; mi < (WARP_M/32); ++mi) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
scores_sum[mi].u64 = hcu_pk_add_f32( scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(
scores_sum[mi].u64, scores_sum[mi].u64,
scores_sum_cur[mi].u64 scores_sum_cur[mi].u64
); );
......
...@@ -38,20 +38,17 @@ __forceinline__ __device__ void kvcache_acco_reduce( ...@@ -38,20 +38,17 @@ __forceinline__ __device__ void kvcache_acco_reduce(
union_vec2_fp32 acc_tmp; union_vec2_fp32 acc_tmp;
int lds_offset0 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 0 * 16 + (lane_id >> 4) * 4 + WARP_ID; int lds_offset0 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 0 * 16 + (lane_id >> 4) * 4 + WARP_ID;
int lds_offset1 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 1 * 16 + (lane_id >> 4) * 4 + WARP_ID; int lds_offset1 = min_tile_m * __kHeadDim + q_seq_idx * 2 * __kHeadDim + h_idx * kBlockK + k_idx * 32 + 1 * 16 + (lane_id >> 4) * 4 + WARP_ID;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0, acc_tmp.u64, 0, 16); acc_tmp.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0, 0, 16, false);
union_vec2_fp32 acc_tmp_wave1; union_vec2_fp32 acc_tmp_wave1;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 1 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave1.u64, 0, 16); acc_tmp_wave1.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 1 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp.f32[0] += acc_tmp_wave1.f32[0]; acc_tmp.f32[0] += acc_tmp_wave1.f32[0];
acc_tmp.f32[1] += acc_tmp_wave1.f32[1]; acc_tmp.f32[1] += acc_tmp_wave1.f32[1];
union_vec2_fp32 acc_tmp_wave2; union_vec2_fp32 acc_tmp_wave2;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 2 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave2.u64, 0, 16); acc_tmp_wave2.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 2 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp.f32[0] += acc_tmp_wave2.f32[0]; acc_tmp.f32[0] += acc_tmp_wave2.f32[0];
acc_tmp.f32[1] += acc_tmp_wave2.f32[1]; acc_tmp.f32[1] += acc_tmp_wave2.f32[1];
union_vec2_fp32 acc_tmp_wave3; union_vec2_fp32 acc_tmp_wave3;
inlineasm_fa_ds_read2_b32(acc_o_lds, lds_offset0 + 3 * EVEN_REUSE_KV_TIMES * __kHeadDim, acc_tmp_wave3.u64, 0, 16); acc_tmp_wave3.u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)acc_o_lds + lds_offset0 + 3 * EVEN_REUSE_KV_TIMES * __kHeadDim, 0, 16, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
acc_tmp.f32[0] += acc_tmp_wave3.f32[0]; acc_tmp.f32[0] += acc_tmp_wave3.f32[0];
acc_tmp.f32[1] += acc_tmp_wave3.f32[1]; acc_tmp.f32[1] += acc_tmp_wave3.f32[1];
// ds_write2_b32 // ds_write2_b32
...@@ -95,4 +92,4 @@ __forceinline__ __device__ void kvcache_acco_reduce( ...@@ -95,4 +92,4 @@ __forceinline__ __device__ void kvcache_acco_reduce(
} }
} }
} }
} }
\ No newline at end of file
...@@ -2,15 +2,15 @@ ...@@ -2,15 +2,15 @@
#include "numeric_types.h" #include "numeric_types.h"
template<int REUSE_KV_TIMES, int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, int Padding, typename ElementAccum> template<int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, int Padding, typename ElementAccum>
__forceinline__ __device__ void kvcache_acco_reduce_tile16x32( __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4], vec4_Accum < ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
ElementAccum* acc_o_lds, ElementAccum* acc_o_lds,
int seqlen_q, int seqlen_q,
int warp_id, int warp_id,
int lane_id) { int lane_id) {
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__)
constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and Padding == 0); // Specialized optimizatio for headdim 128 constexpr int OPT_FOR_HDIM128 = bool(WARP_NUM == 4 and M_MMAC_COUNT == 1 and Padding == 0 and K_LOOP_COUNT == WARP_NUM and K_WARP_COUNT == 1 and M_WARP_COUNT == 1); // Specialized optimization for headdim 128
#else #else
constexpr int OPT_FOR_HDIM128 = false; // keep same as origin for archs <= gfx936 constexpr int OPT_FOR_HDIM128 = false; // keep same as origin for archs <= gfx936
#endif #endif
...@@ -78,69 +78,70 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32( ...@@ -78,69 +78,70 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
} else { } else {
constexpr int kBlockK = K_WARP_COUNT * 32 + Padding; constexpr int kBlockK = K_WARP_COUNT * 32 + Padding;
// when REUSE_KV not in templated, compute max reuse times int EVEN_REUSE_KV_TIMES = ((seqlen_q + 1) / 2) * 2;
int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2: ((seqlen_q + 1) / 2) * 2;
int q_seq_idx = (lane_id & 15); int q_seq_idx = (lane_id & 15);
if (q_seq_idx < EVEN_REUSE_KV_TIMES) { if (q_seq_idx < EVEN_REUSE_KV_TIMES) {
for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) { for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
// #################################################################################################################################################### // ####################################################################################################################################################
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 一个 wave 共同持有 seqlen_q x kHeadDim 个 Half, 但为了节省 lds 用量, 每次只 reduce seqlen_q x kBlockK 个 Half
int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT + q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4/*0~3*/) * 4/*0~15*/;
int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT + k_idx * M_WARP_COUNT;
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32;
}
}
}
__syncthreads();
// 在 lds 中求和, 把 4 个 wave 写的 acc_o 的数据加起来
if constexpr (WARP_NUM == 4) {
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) { for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// 一个 wave 共同持有 seqlen_q x kHeadDim 个 Half, 但为了节省 lds 用量, 每次只 reduce seqlen_q x kBlockK 个 Half int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + warp_id; // 之前是一次性写了 4 个 Half 到 lds, 现在 4 个 wave 分别处理这 4 个位置的 acc_o reduce
int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT + q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4/*0~3*/) * 4/*0~15*/; float acc_tmp_wave0 = acc_o_lds[lds_offset];
*(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32; for (int loop = 1; loop < WARP_NUM; ++loop) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
} }
} }
} }
__syncthreads(); }
// 在 lds 中求和, 把 4 个 wave 写的 acc_o 的数据加起来 // 不是恰好 4 个 wave, 则把 wave 0 单独拎出来做 lds reduce 操作
if constexpr (WARP_NUM == 4) { else if constexpr (WARP_NUM > 1) {
if (warp_id == 0) {
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) { for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + warp_id; // 之前是一次性写了 4 个 Half 到 lds, 现在 4 个 wave 分别处理这 4 个位置的 acc_o reduce for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
float acc_tmp_wave0 = acc_o_lds[lds_offset]; int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
for (int loop = 1; loop < WARP_NUM; ++loop) { float acc_tmp_wave0 = acc_o_lds[lds_offset];
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK]; for (int loop = 1; loop < WARP_NUM; ++loop) {
} acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK];
acc_o_lds[lds_offset] = acc_tmp_wave0;
}
}
}
}
// 不是恰好 4 个 wave, 则把 wave 0 单独拎出来做 lds reduce 操作
else if constexpr (WARP_NUM > 1) {
if (warp_id == 0) {
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
float acc_tmp_wave0 = acc_o_lds[lds_offset];
for (int loop = 1; loop < WARP_NUM; ++loop) {
acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * M_MMAC_COUNT * kBlockK];
}
acc_o_lds[lds_offset] = acc_tmp_wave0;
} }
acc_o_lds[lds_offset] = acc_tmp_wave0;
} }
} }
} }
} }
} }
__syncthreads(); }
// 每个 wave 都从 LDS 获取最终的求和结果 __syncthreads();
for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) { // 每个 wave 都从 LDS 获取最终的求和结果
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) { for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4; for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
acc_o[h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset); int lds_offset = (q_seq_idx + min_tile_m * 16) * kBlockK + k_idx * 32 + min_tile_n * 16 + (lane_id >> 4) * 4;
} int tile_32x32_id = h_idx * M_WARP_COUNT * K_WARP_COUNT + k_idx * M_WARP_COUNT;
acc_o[tile_32x32_id][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
} }
} }
__syncthreads();
} }
__syncthreads();
} }
} }
} }
\ No newline at end of file }
...@@ -21,9 +21,9 @@ __forceinline__ __device__ void kvcache_epilugue_rescale_acco( ...@@ -21,9 +21,9 @@ __forceinline__ __device__ void kvcache_epilugue_rescale_acco(
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int mmac_id = min_tile_n * 2 + min_tile_m; int mmac_id = min_tile_n * 2 + min_tile_m;
int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi); int tile_32x32_id = pv_n_loop * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
for (int vec_id = 0; vec_id < 2; ++vec_id) { for (int vec_id = 0; vec_id < 2; ++vec_id) {
acc_o[tile_32x32_id][mmac_id].u64[vec_id] = hcu_pk_mul_f32( acc_o[tile_32x32_id][mmac_id].u64[vec_id] = __builtin_hcu_pk_mul_f32(
acc_o[tile_32x32_id][mmac_id].u64[vec_id], acc_o[tile_32x32_id][mmac_id].u64[vec_id],
scale_pair scale_pair
); );
...@@ -54,12 +54,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_max_sum( ...@@ -54,12 +54,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_max_sum(
int headdim_split_id, int headdim_split_id,
int seqlen_q_limit int seqlen_q_limit
) { ) {
#ifdef FA_DEBUG_SUM_MAX if constexpr (Split) {
constexpr bool ALLOW_WRITE_SUM_MAX = true;
#else
constexpr bool ALLOW_WRITE_SUM_MAX = false;
#endif
if constexpr (Split or ALLOW_WRITE_SUM_MAX) {
bool write_ok = Is_16x32 ? (thread_id < 16 and headdim_split_id == 0): thread_id < 16; bool write_ok = Is_16x32 ? (thread_id < 16 and headdim_split_id == 0): thread_id < 16;
if (write_ok) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可 if (write_ok) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll #pragma unroll
...@@ -96,12 +91,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_max_sum( ...@@ -96,12 +91,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_max_sum(
int total_q, int total_q,
int ngroups int ngroups
) { ) {
#ifdef FA_DEBUG_SUM_MAX if constexpr (Split) {
constexpr bool ALLOW_WRITE_SUM_MAX = true;
#else
constexpr bool ALLOW_WRITE_SUM_MAX = false;
#endif
if constexpr (Split or ALLOW_WRITE_SUM_MAX) {
bool write_ok = Is_16x32 ? (thread_id < 16 and headdim_split_id == 0): thread_id < 16; bool write_ok = Is_16x32 ? (thread_id < 16 and headdim_split_id == 0): thread_id < 16;
if (write_ok) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可 if (write_ok) { // 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll #pragma unroll
...@@ -191,7 +181,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_output( ...@@ -191,7 +181,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_output(
int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m); int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m);
if (seqlen_q_idx < params.seqlen_q) { if (seqlen_q_idx < params.seqlen_q) {
if constexpr (WARP_NUM == 4) { // for 4 waves, storation can be done togather, performance 4% if constexpr (WARP_NUM == 4) { // for 4 waves, storation can be done togather, performance 4%
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int vec_index = warp_id; int vec_index = warp_id;
int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2; int64_t pv_global_addr = seqlen_q_idx * output_seqlen_stride + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2;
vec2_Element<SplitkvAccumType> result = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[vec_index], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[vec_index]); vec2_Element<SplitkvAccumType> result = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[vec_index], acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[vec_index]);
...@@ -264,7 +254,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output( ...@@ -264,7 +254,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output(
int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m); int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + (Is_16x32 ? pv_lane_seq_idx + min_tile_m * 16: pv_lane_seq_idx * 2 + min_tile_m);
if (seqlen_q_idx < actual_seqlen_q) { if (seqlen_q_idx < actual_seqlen_q) {
if constexpr (WARP_NUM == 4) { // for 4 waves, storation can be done togather, performance 4% if constexpr (WARP_NUM == 4) { // for 4 waves, storation can be done togather, performance 4%
#if defined(__gfx938__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int vec_index = warp_id; int vec_index = warp_id;
int true_seqlen_q = seqlen_q_idx / params.ngroups; int true_seqlen_q = seqlen_q_idx / params.ngroups;
int true_group_id = seqlen_q_idx % params.ngroups; int true_group_id = seqlen_q_idx % params.ngroups;
......
...@@ -74,7 +74,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32( ...@@ -74,7 +74,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4; precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + seq_idx * 32 * kBlockN + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
} }
} }
} }
...@@ -97,7 +97,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32( ...@@ -97,7 +97,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll #pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET); v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
} }
} }
} }
...@@ -168,7 +168,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32( ...@@ -168,7 +168,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
precompute_v_lds_offset[vec_idx] = reinterpret_cast<size_t>(v_lds_v2fp16) + ((stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2) * 4; precompute_v_lds_offset[vec_idx] = (stage_id * WARP_K * kBlockN + (seq_idx * 32 * kBlockN) + head_dim_idx * 32 * 32 + vec_idx * 8 * 32 + v_ds_read_offset) / 2;
} }
} }
} }
...@@ -191,7 +191,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32( ...@@ -191,7 +191,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) { for (int seq_idx = 0; seq_idx < PV_K_WARP_COUNT; ++seq_idx) {
#pragma unroll #pragma unroll
for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) { for (int head_dim_idx = 0; head_dim_idx < PV_N_WARP_COUNT; ++head_dim_idx) {
inline_ds_read2_b32_no_wait_bytes(precompute_v_lds_offset[vec_idx], v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64, NEXT_DWORD_OFFSET); v_reg[stage_id * PV_K_WARP_COUNT * PV_N_WARP_COUNT + (head_dim_idx * PV_K_WARP_COUNT + seq_idx)][vec_idx].u64 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float *)v_lds_v2fp16 + precompute_v_lds_offset[vec_idx], 0, NEXT_DWORD_OFFSET, false);
} }
} }
} }
......
...@@ -39,7 +39,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v( ...@@ -39,7 +39,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v(
int stage_id = 0; int stage_id = 0;
// load 指令发下去之后, 先做一些初始化运算 // load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
if constexpr (M_MMAC_COUNT == 1) { if constexpr (M_MMAC_COUNT == 1) {
inline_vgpr4_init_zero_1x2x4(s_reg); inline_vgpr4_init_zero_1x2x4(s_reg);
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment