// NOTE(review): this text looks like a flattened copy of a HIP/CUDA header. Newlines were
// collapsed, and every `<...>` span (include targets, template parameter/argument lists, cast
// target types) appears to have been stripped. Examples: `#include` with no header, `template`
// with no parameter list, `static_cast(...)`, `std::pair get()`. On a collapsed line, a `//`
// comment also swallows all code after it on that line. As-is this cannot compile. Restore it
// from the upstream source rather than editing it in place -- TODO confirm against upstream.
//
// Next line:
//  - CHECK_CUDA / CHECK_CUDA_KERNEL_LAUNCH: print and exit(1) on a non-success cudaError_t.
//  - FLASH_ASSERT: host-side; prints the condition and calls exit(1).
//  - FLASH_DEVICE_ASSERT: device printf followed by the AMDGCN `s_trap 0` instruction.
//  - println and ceil_div: ceil_div rounds a/b up and assumes a + b - 1 does not overflow T.
//  - TRAP_ONLY_DEVICE_ASSERT: defined twice under the same #ifndef guard, so the second copy
//    is redundant. NOTE(review): it uses `asm("trap;")` while FLASH_DEVICE_ASSERT uses
//    `s_trap 0`; confirm that `trap;` assembles for the AMDGCN target.
//  - RingBufferState: a block counter that only increases via update(). get() maps it to
//    stage = idx % NUM_STAGES and phase = (idx / NUM_STAGES) & 1. offset_by() leaves
//    underflow avoidance to the caller.
// The line ends with the head of the BOOL_SWITCH macro.
#pragma once #include #include #include #include #include #include #include #include #include #include "defines.h" #include "params.h" #define CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ exit(1); \ } \ } while(0) #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) #define FLASH_ASSERT(cond) \ do { \ if (not (cond)) { \ fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ exit(1); \ } \ } while(0) #define FLASH_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) { \ printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ asm volatile("s_trap 0 \n\t"); \ } \ } while(0) #define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); } template __inline__ __host__ __device__ T ceil_div(const T &a, const T &b) { return (a + b - 1) / b; } #ifndef TRAP_ONLY_DEVICE_ASSERT #define TRAP_ONLY_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) \ asm("trap;"); \ } while (0) #endif #ifndef TRAP_ONLY_DEVICE_ASSERT #define TRAP_ONLY_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) \ asm("trap;"); \ } while (0) #endif struct RingBufferState { uint32_t cur_block_idx = 0u; __device__ __forceinline__ void update() { cur_block_idx += 1; } template __device__ __forceinline__ std::pair get() const { uint32_t stage_idx = cur_block_idx % NUM_STAGES; bool phase = (cur_block_idx / NUM_STAGES) & 1; return {stage_idx, phase}; } __device__ __forceinline__ RingBufferState offset_by(const int offset) const { // Must guarantee no underflow uint32_t new_block_idx = static_cast(static_cast(cur_block_idx) + offset); RingBufferState new_state; new_state.cur_block_idx = new_block_idx; return new_state; } }; #define BOOL_SWITCH(COND, CONST_NAME, ...) 
// Next line:
//  - Tail of BOOL_SWITCH: invokes the lambda body with CONST_NAME bound to a constexpr bool
//    equal to COND.
//  - namespace flash begins.
//  - MaxOp returns the larger operand; the float specialization calls max().
//  - SumOp returns x + y.
//  - Allreduce<THREADS> is a butterfly reduction. It combines x with __shfl_xor(x, THREADS/2, 64)
//    and recurses on THREADS/2 down to Allreduce<1>, which returns x unchanged.
//  - NOTE(review): the Allreduce<32> specialization performs only one xor-16 exchange. The
//    generic template would do xor 16/8/4/2/1. Confirm that the partial reduction is intended
//    for the MMA lane layout.
// The line ends in the signature of flash::copy.
\ [&] { \ if (COND) { \ constexpr static bool CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct MaxOp { __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } }; template <> struct MaxOp { // This is slightly faster __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Allreduce { static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2); template static __device__ __forceinline__ T run(T x, Operator &op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor(x, OFFSET, 64)); return Allreduce::run(x, op); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Allreduce<1> { // static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2); template static __device__ __forceinline__ T run(T x, Operator &op) { return x; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Allreduce<32> { template static __device__ __forceinline__ T run(T x, Operator &op) { x = op(x, __shfl_xor(x, 16, 64)); return x; } }; template __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_K, const int max_MN=0, int begin_k=0) { 
// Next line:
//  - Body of flash::copy: a predicated copy over rank-3 (MMA, MMA_M, MMA_K) tensors.
//    Rows whose identity_MN coordinate is >= max_MN are skipped, or cleared when Clear_OOB_MN.
//    K slices with predicate_K(k) false are skipped, or cleared when Clear_OOB_K.
//    NOTE(review): begin_k is accepted but not used in the visible body.
//  - __ds_read_m32x16_row_col_rrow and its _alt variant: read from LDS via
//    __builtin_amdgcn_ds_read_m32x16f16(_alt) at the compile-time byte offset
//    layout(0, row, col) * 2. They copy the eight returned 16-bit values into dst(0, r_row, col).
// The line ends with the `template` header of the next function.
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K // There's no case where !Clear_OOB_K && Clear_OOB_MN static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { cute::clear(D(_, m, _)); } } } template __forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Tensor1& dst) { auto lds = reinterpret_cast<__fp16 *>(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = reinterpret_cast(&(dst(0, r_row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } template __forceinline__ __device__ void __ds_read_m32x16_row_col_rrow_alt(Tensor0& src, Tensor1& dst) { auto lds = reinterpret_cast<__fp16 *>(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(lds), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = reinterpret_cast(&(dst(0, r_row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } template 
// Next line:
//  - __ds_read_m32x16_row_col: same as the _rrow variant, but writes into dst(0, row, col).
//  - fp8e4m3_to_fp32: a bit-level E4M3 -> fp32 conversion. It renormalizes subnormal inputs
//    via __clz and rebiases the exponent by 0x78 (120).
//    NOTE(review): for an input of +/-0, nonsign == 0 and __clz(0) == 32, so
//    renorm_shift == 28. The result is then sign | (0x5C << 23) = +/-2^-35 rather than +/-0.
//    NaN encodings are not special-cased either. Confirm whether callers depend on this.
//  - convert_layout_acc_rowcol: regroups a rank-3 accumulator layout into two modes.
// The line ends at the start of convert_type.
__forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst) { auto lds = reinterpret_cast<__fp16 *>(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = reinterpret_cast(&(dst(0, row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } inline __device__ float fp8e4m3_to_fp32(const fp8& input) { const uint32_t w = (uint32_t)input << 24; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); uint32_t renorm_shift = __clz(nonsign); renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)); union { uint32_t as_bits; float as_value; } fp32 = {result}; return fp32.as_value; } template __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { // static_assert(decltype(size<0>(acc_layout))::value == 4 || decltype(size<0>(acc_layout))::value == 8); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_1>{}); // (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4) return make_layout(make_layout(get<1>(l)), make_layout(get<1>(get<0>(l)), get<2>(l))); // (1, (4, 2)):((_0),(_1,_4)) }; template __forceinline__ __device__ auto convert_type(Tensor const &tensor) { using From_type = typename Engine::value_type; if constexpr (std::is_same_v) { return tensor; } constexpr int numel = decltype(size(tensor))::value; Tensor tensor_To_type = make_tensor(layout(tensor)); cutlass::Array *result_ptr = reinterpret_cast *>(tensor_To_type.data()); #if defined(__gfx938__) { if constexpr (std::is_same_v) { 
// Next line:
//  - Rest of convert_type: converts the whole register tensor with one
//    cutlass::NumericArrayConverter call into a new tensor with the same layout. There is a
//    separate branch chain under __gfx938__, and the FLASH_MLA_BF16_TYPE macro (default 0)
//    picks between two converters whose template arguments are not visible here.
//    NOTE(review): the early `if constexpr (...) { return tensor; }` has no `else`. In that
//    case the remaining code is still instantiated, so two different return types may be
//    deduced; consider wrapping the rest in `else`. The reinterpret_cast of tensor.data()
//    assumes a contiguous tensor (see the HACK note in the trailing comment).
//  - convert_layout_acc_Aregs begins: with lane = threadIdx.x % 64 and warp_id = threadIdx.x / 64,
//    each thread stores tOrP(0..3, 0, 0) into sAcc at warp-dependent indices, then calls
//    __syncthreads().
cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else if constexpr (std::is_same_v) { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } return tensor_To_type; } #else { if constexpr (std::is_same_v) { #ifndef FLASH_MLA_BF16_TYPE #define FLASH_MLA_BF16_TYPE 0 #endif #if FLASH_MLA_BF16_TYPE == 0 cutlass::NumericArrayConverter convert_op; #else cutlass::NumericArrayConverter convert_op; #endif *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } return tensor_To_type; } #endif // cutlass::NumericArrayConverter convert_op; // // HACK: this requires tensor to be "contiguous" // auto frag = convert_op(*reinterpret_cast *>(tensor.data())); // return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } template __forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP, Tensor const& sAcc) { using Value_type = typename Engine0::value_type; int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; sAcc((tid % 16 ) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(0, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + 1 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(1, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + 2 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(2, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + 3 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(3, 0, 0); __syncthreads(); using SmemLayoutAtomP = Layout, Int<64>>, Stride, _1>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape, Int<64>>{})); Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{}); auto thr_mma = 
// Next line:
//  - Rest of convert_layout_acc_Aregs: partitions a shared-memory view of sAcc as an A operand
//    of tiled_mma_o and fills tSrACC(0..3, 0, 0..3) from sAcc(tid * 8 + {0..3}), with extra
//    offsets of +4 and +16*32.
//    NOTE(review): both get_thread_slice(tid) and the reads use the lane index (threadIdx.x % 64)
//    only. Every wave therefore reads the same sAcc entries even though the writes above were
//    warp-dependent; confirm this is intended. No barrier follows the reads, so callers
//    presumably must synchronize before sAcc is reused.
//  - is_positive_infinity: a bitwise comparison against INFINITY.
//  - lds_direct_copy begins. On __gfx936__/__gfx938__ it builds a 4-word buffer descriptor
//    (words 0/1 = src address; words 2/3 = 0x80000000 / 0x00020000, presumably num_records and
//    config bits; confirm against the ISA docs). It then issues
//    `buffer_load_dwordx4 ... lds` with m0 set to the per-wave LDS destination.
tiled_mma_o.get_thread_slice(tid); Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp); tSrACC(0, 0, 0) = sAcc(tid * 8 + 0); tSrACC(1, 0, 0) = sAcc(tid * 8 + 1); tSrACC(2, 0, 0) = sAcc(tid * 8 + 2); tSrACC(3, 0, 0) = sAcc(tid * 8 + 3); tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 4); tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 4); tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 4); tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 4); tSrACC(0, 0, 2) = sAcc(tid * 8 + 0 + 16*32); tSrACC(1, 0, 2) = sAcc(tid * 8 + 1 + 16*32); tSrACC(2, 0, 2) = sAcc(tid * 8 + 2 + 16*32); tSrACC(3, 0, 2) = sAcc(tid * 8 + 3 + 16*32); tSrACC(0, 0, 3) = sAcc(tid * 8 + 0 + 4 + 16*32); tSrACC(1, 0, 3) = sAcc(tid * 8 + 1 + 4 + 16*32); tSrACC(2, 0, 3) = sAcc(tid * 8 + 2 + 4 + 16*32); tSrACC(3, 0, 3) = sAcc(tid * 8 + 3 + 4 + 16*32); return tSrACC; } __forceinline__ __device__ bool is_positive_infinity(const float& f_val) { union Fp32{ uint32_t as_bits; float as_value; }; Fp32 fp32; fp32.as_value = f_val; Fp32 inf_tmp; inf_tmp.as_value = INFINITY; return fp32.as_bits == inf_tmp.as_bits; } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { #if defined(__gfx936__) || defined(__gfx938__) { if constexpr (Is_load_Q) { // // 32x64 constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; 
// Next line:
//  - Is_load_Q path of lds_direct_copy: lane maps to (row = lane % 16, col = lane / 16), with
//    8 two-byte elements per lane. The column offset is (col + warp_id * 4) * 8 + k_idx * 128.
//    Rows >= max_MN (when !Is_even_MN) and columns >= 576 (when !Is_even_K) get offset_v = -1.
//    NOTE(review): 576 is a hard-coded bound, and the -1 offset presumably lands past
//    num_records so the load is dropped; confirm.
//  - The non-Q path begins with a swizzled lane -> (row, col) mapping.
global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; constexpr int bytes_per_warp = warp_size * 8 * element_size; int mma_k = 16*128; int row = lane % 16; int col = lane / 16; int row_offset = row ; int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); } else { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; constexpr int bytes_per_warp = warp_size * 8 * element_size; int mma_k = 32*64; // int row = lane / 4; // int col = lane % 4; // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // TODO(optimize): for the last 8 rows the row indices need to be swapped int virtual_row = lane / 8; int virtual_col = lane % 8; int swizzle_col = virtual_row ^ 
// Next line:
//  - Rest of the non-Q path: row = lane / 4, and `(row >= 8) ^ row` swaps rows 8..15 pairwise
//    (8<->9, 10<->11, ...). col = (lane / 8 ^ lane % 8) % 4. Rows are offset by warp_id * 16,
//    with 8 elements per lane and a column step of k_idx * 32.
//  - lds_direct_copy_for_prefill_sparse_mla begins. It ORs (row_stride * 2) << 16 into the high
//    descriptor word (the stride field per the trailing comment) and stores max_MN as
//    descriptor word 2.
virtual_col; int row = lane / 4; // 8->9 9->8 row = (row >= 8 ) ^ row; // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row; int col = swizzle_col % 4; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); } } #endif } template CUTE_HOST_DEVICE void lds_direct_copy_for_prefill_sparse_mla( Tensor const& src, Tensor & dst, int row_offset, int col, int k_idx_, const int row_stride, int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); glob_ptr.latter |= ((row_stride * 2) << 16); // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = max_MN; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; constexpr int bytes_per_warp = warp_size * 8 * element_size; int mma_k = 32*64; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (col_offset) * element_size; // 
// Next line:
//  - Rest of lds_direct_copy_for_prefill_sparse_mla: the idxen index is row_offset, or max_MN
//    when row_offset == -1. The byte offset is (col * 8 + k_idx * 32) * 2. The LDS destination
//    cycles over 4 k slices via (k_idx % 4). The explicit row bounds check is commented out.
//    Out-of-range rows presumably rely on max_MN being the descriptor's record count.
//    NOTE(review): with the default max_MN = 0, every row index would be out of range; confirm
//    that callers always pass it.
//  - buffer_load_copy begins: it loads 16 bytes per lane from global memory into a register
//    value (uint128_t dst) rather than into LDS.
bytes // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 4) * mma_k * element_size; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); uint32x2_t index_offset = {0}; index_offset[0] = row_offset == -1 ? max_MN : row_offset; index_offset[1] = offset_v; asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); } template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy( Tensor const& src, uint128_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 8; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { asm 
// Next line:
//  - Rest of buffer_load_copy. The mma_layout path uses row = threadIdx.x % 16 + warp_id * 16
//    and col = lane / 16. The other path uses row = threadIdx.x / 4 (no separate warp offset)
//    and col = lane % 4. Each path issues the load via inline asm (use_asm) or via
//    __builtin_amdgcn_buffer_load_dwordx4.
//    NOTE(review): offset_k and mma_k are not used in the visible body.
//  - buffer_to_tensor: stores the 16-byte value at &dst(0, 0, k_idx).
//  - convert_layout_acc_Aregs_dense begins. Its visible code matches convert_layout_acc_Aregs
//    apart from commented-out lines.
volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } else { uint32x4_t global_addr = {0}; *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx / 4; int col = lane % 4; int row_offset = row; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } template< class SrcEngine, class SrcLayout> CUTE_HOST_DEVICE void buffer_to_tensor(const uint128_t & src, Tensor & dst, int k_idx) { uint128_t* d = reinterpret_cast(&dst(0, 0, k_idx)); d[0] = src; } template __forceinline__ __device__ auto convert_layout_acc_Aregs_dense(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP, Tensor const& sAcc) { using Value_type = typename Engine0::value_type; int tid = threadIdx.x % 64; int warp_id = threadIdx.x / 64; // __fp16 *smem_ptr = // sAcc((tid % 16 ) * 4 + (tid / 16) + warp_id * 16 * 16) = tOrP(0, 0, 0); // sAcc((tid % 16 ) * 4 + (tid / 16) + 16 * 4 + warp_id * 16 * 16) = tOrP(1, 0, 0); // sAcc((tid % 16 ) * 4 + (tid / 16) + 2 * 16 * 4 + warp_id * 16 * 16) = tOrP(2, 0, 0); // sAcc((tid % 16 ) * 4 + (tid / 16) + 3 * 16 * 4 + warp_id * 16 * 16) = tOrP(3, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(0, 0, 0); sAcc((tid % 16 ) * 8 + 
// Next line:
//  - Rest of convert_layout_acc_Aregs_dense.
//  - lds_direct_copy_qkvfp8 begins (1-byte elements). Its Is_load_Q path starts here.
(tid / 16) + 1 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(1, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + 2 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(2, 0, 0); sAcc((tid % 16 ) * 8 + (tid / 16) + 3 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(3, 0, 0); __syncthreads(); using SmemLayoutAtomP = Layout, Int<64>>, Stride, _1>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape, Int<64>>{})); Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{}); auto thr_mma = tiled_mma_o.get_thread_slice(tid); Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp); tSrACC(0, 0, 0) = sAcc(tid * 8 + 0); tSrACC(1, 0, 0) = sAcc(tid * 8 + 1); tSrACC(2, 0, 0) = sAcc(tid * 8 + 2); tSrACC(3, 0, 0) = sAcc(tid * 8 + 3); tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 4); tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 4); tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 4); tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 4); tSrACC(0, 0, 2) = sAcc(tid * 8 + 0 + 16*32); tSrACC(1, 0, 2) = sAcc(tid * 8 + 1 + 16*32); tSrACC(2, 0, 2) = sAcc(tid * 8 + 2 + 16*32); tSrACC(3, 0, 2) = sAcc(tid * 8 + 3 + 16*32); tSrACC(0, 0, 3) = sAcc(tid * 8 + 0 + 4 + 16*32); tSrACC(1, 0, 3) = sAcc(tid * 8 + 1 + 4 + 16*32); tSrACC(2, 0, 3) = sAcc(tid * 8 + 2 + 4 + 16*32); tSrACC(3, 0, 3) = sAcc(tid * 8 + 3 + 4 + 16*32); return tSrACC; } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { if constexpr (Is_load_Q) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = 
// Next line:
//  - Is_load_Q path of lds_direct_copy_qkvfp8: same structure as lds_direct_copy's Q path, but
//    with 16 one-byte elements per lane, a column step of k_idx * 256, and the same hard-coded
//    576 column bound.
//  - The non-Q path begins.
reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; int mma_k = 16*256; int row = lane % 16; int col = lane / 16; int row_offset = row ; int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); } else { constexpr int warp_size = 64; int tidx = threadIdx.x;//0-256 int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size;//0-63 constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576 const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; 
// Next line:
//  - Non-Q path of lds_direct_copy_qkvfp8: the same swizzled lane mapping as lds_direct_copy,
//    with a column step of k_idx * 64. The LDS load is emitted only when __gfx938__ is defined;
//    on other targets this path does nothing.
//  - buffer_load_copy_qkvfp8 begins (1-byte elements, 16 per lane).
constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1 int mma_k = 64*64; // int row = lane / 4; // int col = lane % 4; // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // TODO(optimize): for the last 8 rows the row indices need to be swapped int virtual_row = lane / 8;//0 int virtual_col = lane % 8;//0 int swizzle_col = virtual_row ^ virtual_col; int row = lane / 4;//0 // 8->9 9->8 row = (row >= 8 ) ^ row; // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row; int col = swizzle_col % 4; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; //int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size; #if defined(__gfx938__) asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); #endif } } template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy_qkvfp8( Tensor const& src, uint128_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 16; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: 
// Next line:
//  - Rest of buffer_load_copy_qkvfp8. Only the mma_layout path is implemented.
//    NOTE(review): when mma_layout is false, dst is never assigned.
//  - __ds_read_m32x32_row_col_rrow: reads via __builtin_amdgcn_ds_read_m32x32u8 at byte offset
//    layout(0, row, col) into an intx4_t.
//  - lds_direct_copy_fp8 begins. Its Is_load_Q branch is empty.
cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } template __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst) { auto lds = reinterpret_cast(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 1; auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset); dst = d; } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_fp8( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { if constexpr (Is_load_Q) { } else { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = 
// Next line:
//  - Rest of lds_direct_copy_fp8: the same mapping as the non-Q path of lds_direct_copy_qkvfp8.
//    The load is emitted only on __gfx936__/__gfx938__.
//  - fp8e4m3_to_bf16: shifts the 7 non-sign bits right by 4 and adds 0x78 << 7 to rebias the
//    exponent.
//    NOTE(review): there is no zero or subnormal handling. An input of 0x00 yields 0x3C00, which
//    is bf16 2^-7 = 0.0078125, instead of 0; the zero check is commented out. Subnormal E4M3
//    inputs are also not renormalized.
//  - fp8e5m2_to_fp32 begins.
__builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; int mma_k = 64*64; // int row = lane / 4; // int col = lane % 4; // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // TODO(optimize): for the last 8 rows the row indices need to be swapped int virtual_row = lane / 8; int virtual_col = lane % 8; int swizzle_col = virtual_row ^ virtual_col; int row = lane / 4; // 8->9 9->8 row = (row >= 8 ) ^ row; // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row; int col = swizzle_col % 4; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; // if (thread(0)) // { // printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size); // } #if defined(__gfx936__) || defined(__gfx938__) asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); #endif } } __forceinline__ __device__ cutlass::bfloat16_t fp8e4m3_to_bf16(const fp8& input) { const uint16_t w = (uint16_t)input << 8; const uint16_t sign = w & UINT16_C(0x8000); const uint16_t nonsign = w & UINT16_C(0x7FFF); constexpr uint16_t exp_offset=(0x78 << 7); uint16_t result = sign | ((nonsign >> 4) + exp_offset); // if(nonsign == 0x0000) result = 0x0000; // if (thread0() && nonsign == 0x0000) // { // printf(" input = %x result = %x\n", input, result); // } return cutlass::bfloat16_t::bitcast(result); } __forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) { union uf16{ 
// Next line:
//  - Rest of fp8e5m2_to_fp32: places the fp8 byte in the high byte of a 16-bit pattern,
//    reinterprets it as _Float16, and widens it to float.
//  - fp8e5m2_to_fp16: the same bit placement, returned via half_t::bitcast. The uf16/uf32
//    unions in it are unused.
//  - wait_vmcnt<N>(): emits `s_waitcnt vmcnt(N)` followed by `s_barrier`. Because of the
//    barrier, presumably every wave of the workgroup must reach it.
//  - `#if 0` begins a disabled gemm_rs_fp8.
uint16_t as_bits; _Float16 as_value; } ; union uf32 { uint32_t as_bits; float as_value; }; uf16 u16; uf32 u32; u16.as_bits = (uint16_t)input << 8; u32.as_value = (float)u16.as_value; // return u32.as_bits>>16; return u32.as_value; } __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) { union uf16{ uint16_t as_bits; __fp16 as_value; } ; union uf32 { uint32_t as_bits; float as_value; }; uf16 u16; // uf32 u32; // u16.as_bits = (uint16_t)input << 8; // u32.as_value = (float)u16.as_value; // return u32.as_bits>>16; uint16_t output = (uint16_t)(input << 8); return cutlass::half_t::bitcast(output); } template CUTE_HOST_DEVICE void wait_vmcnt() { asm volatile("s_waitcnt vmcnt(%0) ;\n\t" "s_barrier; \n\t" :: "n"(N)); } #if 0 template __forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B, const float& k_scale ) { typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; union { __fp16x8_t data_128; __hip_fp8x4_storage_t fp8_array[4]; } data[8]; __builtin_amdgcn_sched_barrier(0); wait_vmcnt<8>(); data[0].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 0)); wait_vmcnt<7>(); data[1].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 1)); wait_vmcnt<6>(); data[2].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 2)); wait_vmcnt<5>(); data[3].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 3)); wait_vmcnt<4>(); data[4].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 4)); wait_vmcnt<3>(); data[5].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 5)); wait_vmcnt<2>(); data[6].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 6)); wait_vmcnt<1>(); data[7].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 7)); 
// Next line (inside the disabled `#if 0` region): gemm_rs_fp8 continued. For each of the 8
// k slices it unpacks four fp8 bytes at a time and converts them according to KV_DTYPE and the
// element type. It optionally multiplies by k_scale and then calls cute::gemm on that k slice.
__builtin_amdgcn_sched_barrier(0); #pragma unroll for (int k_idx = 0; k_idx < 8; k_idx++) { #pragma unroll for (int j = 0; j < 16; j+=4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx].fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx].fp8_array[j / 4])) + 1); if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); auto rst0 = fp8e4m3_to_bf16(f1); auto rst1 = fp8e4m3_to_bf16(f2); auto rst2 = fp8e4m3_to_bf16(f3); auto rst3 = fp8e4m3_to_bf16(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } // if (block0()) // { // printf("threadIdx.x = %d %.2f %.2f %.2f %.2f \n", threadIdx.x, f1, f2, f3, f4); // } cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { // auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); // auto f2 = 
// Next line (still inside `#if 0`): the remaining gemm_rs_fp8 branches (E5M2 via byte masking,
// and E4M3 via fp32 plus __builtin_amdgcn_cvt_pkrtz), the end of gemm_rs_fp8, and the first
// tokens of gemm_k_rs_fp8, whose body continues beyond this chunk.
fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); // auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); // auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); // tCrB(j, 0, k_idx) = f1; // tCrB(j + 1, 0, k_idx) = f2; // tCrB(j + 2, 0, k_idx) = f3; // tCrB(j + 3, 0, k_idx) = f4; __hip_fp8x4_storage_t fp8_data = data[k_idx].fp8_array[j / 4]; union Fp8_data_union{ __hip_fp8x4_storage_t fp8x4; uint16_t fp16[2]; }; Fp8_data_union first_fp8, last_fp8; first_fp8.fp8x4 = ((fp8_data & 0xff00ff00)); last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8); tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]); tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);; tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);; tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2); auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4); cutlass::Array result0 = reinterpret_cast &>(rst0); cutlass::Array result1 = reinterpret_cast &>(rst1); tCrB(j, 0, k_idx) = result0[0]; tCrB(j + 1, 0, k_idx) = result0[1]; tCrB(j + 2, 0, k_idx) = result1[0]; tCrB(j + 3, 0, k_idx) = result1[1]; } } cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc); } } template __forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, TiledMma tiled_mma, uint32x4_t& _data, const float& k_scale) { typedef __fp16 
__fp16x8_t __attribute__((ext_vector_type(8))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; union { uint32x4_t data_128; __hip_fp8x4_storage_t fp8_array[4]; } data; data.data_128 = _data; for (int j = 0; j < 16; j+=4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data.fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data.fp8_array[j / 4])) + 1); if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); auto rst0 = fp8e4m3_to_bf16(f1); auto rst1 = fp8e4m3_to_bf16(f2); auto rst2 = fp8e4m3_to_bf16(f3); auto rst3 = fp8e4m3_to_bf16(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } // if (thread0()) { // printf(" static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8) = %x f1 = %.2f\n", static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), f1); // } cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 
3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { __hip_fp8x4_storage_t fp8_data = data.fp8_array[j / 4]; union Fp8_data_union{ __hip_fp8x4_storage_t fp8x4; uint16_t fp16[2]; } ; Fp8_data_union first_fp8, last_fp8; first_fp8.fp8x4 = ((fp8_data & 0xff00ff00)); last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8); tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]); tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);; tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);; tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2); auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4); cutlass::Array result0 = reinterpret_cast &>(rst0); cutlass::Array result1 = reinterpret_cast &>(rst1); tCrB(j, 0, k_idx) = result0[0]; tCrB(j + 1, 0, k_idx) = result0[1]; tCrB(j + 2, 0, k_idx) = result1[0]; tCrB(j + 3, 0, k_idx) = result1[1]; } } cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc); } #endif template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy_fp8( Tensor const& src, uint32x4_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; 
constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 16; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } #if 0 template __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B, const float& k_scale ) { typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; auto lds = reinterpret_cast<__fp16 *>(&tCsB(0, 0, 0)); auto layout = tCsB.layout(); union { __fp16x8_t data_128; __hip_fp8x4_storage_t fp8_array[4]; } data[8]; constexpr short offset0 = layout(0, 0, 0) * 2; data[0].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset0); constexpr short offset1 = layout(0, 1, 0) * 2; data[1].data_128 = 
__builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset1); constexpr short offset2 = layout(0, 0, 1) * 2; data[2].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset2); constexpr short offset3 = layout(0, 1, 1) * 2; data[3].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset3); constexpr short offset4 = layout(0, 0, 2) * 2; data[4].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset4); constexpr short offset5 = layout(0, 1, 2) * 2; data[5].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset5); constexpr short offset6 = layout(0, 0, 3) * 2; data[6].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset6); constexpr short offset7 = layout(0, 1, 3) * 2; data[7].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset7); #pragma unroll for (int k_idx = 0; k_idx < 4; k_idx++) { #pragma unroll for (int i = 0; i < 2; i++) { #pragma unroll for (int j = 0; j < 16; j+=4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx * 2 + i].fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx * 2 + i].fp8_array[j / 4])) + 1); if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { // cutlass::NumericConverter convert_; auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); auto rst0 = fp8e4m3_to_bf16(f1); auto rst1 = fp8e4m3_to_bf16(f2); auto rst2 = fp8e4m3_to_bf16(f3); auto rst3 = fp8e4m3_to_bf16(f4); tCrB(j, i, k_idx) = rst0; tCrB(j + 1, i, k_idx) = rst1; tCrB(j + 2, 
i, k_idx) = rst2; tCrB(j + 3, i, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); tCrB(j, i, k_idx) = rst0; tCrB(j + 1, i, k_idx) = rst1; tCrB(j + 2, i, k_idx) = rst2; tCrB(j + 3, i, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { // auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); // auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); // auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); // auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); // tCrB(j, i, k_idx) = f1; // tCrB(j + 1, i, k_idx) = f2; // tCrB(j + 2, i, k_idx) = f3; // tCrB(j + 3, i, k_idx) = f4; __hip_fp8x4_storage_t fp8_data = data[k_idx * 2 + i].fp8_array[j / 4]; union Fp8_data_union{ __hip_fp8x4_storage_t fp8x4; uint16_t fp16[2]; } ; Fp8_data_union first_fp8, last_fp8; first_fp8.fp8x4 = ((fp8_data & 0xff00ff00)); last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8); tCrB(j, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]); tCrB(j + 1, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);; tCrB(j + 2, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);; tCrB(j + 3, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v){ auto f1 = 
fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f3); auto rst1 = __builtin_amdgcn_cvt_pkrtz(f2, f4); cutlass::Array result0 = reinterpret_cast &>(rst0); cutlass::Array result1 = reinterpret_cast &>(rst1); tCrB(j, i, k_idx) = result0[0]; tCrB(j + 1, i, k_idx) = result1[0]; tCrB(j + 2, i, k_idx) = result0[1]; tCrB(j + 3, i, k_idx) = result1[1]; } } } cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc); } } #endif typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); template __forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds_read_ptr, v4f* accs_f32) { typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4))); union Bf16_storage { __fp16x8_t data_128; __fp16x4_t data_64[2]; uint16_t data_array[8]; }; constexpr int k_idx_even = k_idx % 4; constexpr int n_offset = 16 * 32; constexpr int k_offset = k_idx_even * 64 * 32; Bf16_storage q_reg; Bf16_storage k_reg; q_reg.data_128 = q_data; k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset); // q_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(q_lds_read_ptr), k_offset, 2, 1, 0); // k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 0 * n_offset + k_offset, 2, 1, 0); #if defined(__gfx938__) accs_f32[0] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0], true,false); accs_f32[0] = 
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0], true,false); #else accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0]); accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0]); #endif k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 1 * n_offset); #if defined(__gfx938__) accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1], true,false); accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1], true,false); #else accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1]); accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1]); #endif // k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 1 * n_offset + k_offset, 2, 1, 0); k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 2 * n_offset); // k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 2 * n_offset + k_offset, 2, 1, 0); #if defined(__gfx938__) accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2], true,false); accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2], true,false); #else accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2]); accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2]); #endif k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 3 * n_offset); // k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) 
short*)(k_lds_read_ptr), 3 * n_offset + k_offset, 2, 1, 0); #if defined(__gfx938__) accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3], true,false); accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3], true,false); #else accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3]); accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3]); #endif } typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4))); template __forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr, v4f* acco_f32) { constexpr int k_idx_even = k_idx % 1; constexpr int n_offset = 16 * 32 * 2; typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); union Bf16_storage { __fp16x8_t data_128; __fp16x4_t data_64[2]; uint16_t data_array[8]; }; constexpr int k_offset = k_idx_even * 16 * 512 * 2; // #if 1 Bf16_storage v_reg; v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + n_idx_val * n_offset); #if defined(__gfx938__) acco_f32[n_idx_val * 2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[0], acco_f32[n_idx_val * 2], true, false); acco_f32[n_idx_val * 2 + 1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1], true, false); #else acco_f32[n_idx_val * 2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[0], acco_f32[n_idx_val * 2]); acco_f32[n_idx_val * 2 + 1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1]); #endif } }