// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h #pragma once #include #include #include // #include // #include #include #include #include #include #include "flash_mla.h" //////////////////////////////////////////////////////////////////////////////////////////////////// namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// template struct MaxOp { __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } }; template <> struct MaxOp { // This is slightly faster __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Allreduce { static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2); template static __device__ __forceinline__ T run(T x, Operator &op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor(x, OFFSET, 64)); return Allreduce::run(x, op); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Allreduce<1> { // static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2); template static __device__ __forceinline__ T run(T x, Operator &op) { return x; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Allreduce<32> { template static __device__ __forceinline__ T run(T x, Operator &op) { x = op(x, __shfl_xor(x, 16, 64)); return x; } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ auto convert_type(Tensor const &tensor) { using From_type = typename Engine::value_type; if constexpr (std::is_same_v) { return tensor; } constexpr int numel = decltype(size(tensor))::value; Tensor tensor_To_type = make_tensor(layout(tensor)); cutlass::Array *result_ptr = reinterpret_cast *>(tensor_To_type.data()); #if defined(__gfx938__) { if constexpr (std::is_same_v) { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else if constexpr (std::is_same_v) { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } return tensor_To_type; } #else { if constexpr (std::is_same_v) { #ifndef FLASH_MLA_BF16_TYPE #define FLASH_MLA_BF16_TYPE 0 #endif #if FLASH_MLA_BF16_TYPE == 0 cutlass::NumericArrayConverter convert_op; #else cutlass::NumericArrayConverter convert_op; #endif *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } else { cutlass::NumericArrayConverter convert_op; *result_ptr = convert_op(*reinterpret_cast *>(tensor.data())); } return tensor_To_type; } #endif // cutlass::NumericArrayConverter convert_op; // // HACK: this requires tensor to be "contiguous" // auto frag = convert_op(*reinterpret_cast *>(tensor.data())); // return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } //////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst) { auto lds = reinterpret_cast<__fp16 *>(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) 
__fp16*)(lds), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = reinterpret_cast(&(dst(0, row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } template __forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Tensor1& dst) { auto lds = reinterpret_cast<__fp16 *>(src.data().get()); auto layout = src.layout(); constexpr short offset = layout(0, row, col) * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = reinterpret_cast(&(dst(0, r_row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } //////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { // static_assert(decltype(size<0>(acc_layout))::value == 4 || decltype(size<0>(acc_layout))::value == 8); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_1>{}); // (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4) return make_layout(make_layout(get<1>(l)), make_layout(get<1>(get<0>(l)), get<2>(l))); // (1, (4, 2)):((_0),(_1,_4)) }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, Tensor const& tOrP, Tensor const& sAcc) { using Value_type = typename Engine0::value_type; int tidx = threadIdx.x; auto thr_mma = tiled_mma.get_thread_slice(tidx); auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom{}, tiled_mma); auto smem_thr_copy_ACC = 
smem_tiled_copy_ACC.get_thread_slice(tidx); Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP); Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc); // if (cute::thread0()) // { taccOr // raw_ptr_16b(0x2000000000010) o ((_1,_4),_1,_4):((_0,_1),_0,_4) // print("taccOr\n"); print(taccOr); print("\n"); // } cute::copy(smem_tiled_copy_ACC, taccOr, taccOs); // asm volatile("s_waitcnt lgkmcnt(0)\n\t"); __syncthreads(); auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc); Tensor tSrACC = thr_mma.partition_fragment_A(sAcc); Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC); cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view); // asm volatile("s_waitcnt lgkmcnt(0)\n\t"); // __syncthreads(); // 取消这个sync,2024.06.13 return tSrACC; } //////////////////////////////////////////////////////////////////////////////////////////////////// template __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_K, const int max_MN=0, int begin_k=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K // There's no case where !Clear_OOB_K && Clear_OOB_MN static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { cute::clear(D(_, m, _)); } } } template __forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const 
&identity_MN, Tensor const &predicate_K, const int max_MN=0, int k_idx=0, int k_idx_smem=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M // CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K // There's no case where !Clear_OOB_K && Clear_OOB_MN static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { if (Is_even_K || predicate_K(k_idx)) { cute::copy(tiled_copy, S(_, m, k_idx), D(_, m, k_idx_smem)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k_idx)); } } else if (Clear_OOB_MN) { cute::clear(D(_, m, _)); } } } template CUTE_HOST_DEVICE void wait_vmcnt() { __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(%0) ;\n\t" "s_barrier; \n\t" :: "n"(N)); __builtin_amdgcn_sched_barrier(0); } template< class SrcEngine, class SrcLayout> CUTE_HOST_DEVICE void asm_ds_write(const uint128_t & src, Tensor & dst, int k_idx) { uint128_t* d = reinterpret_cast(&dst(0, 0, k_idx)); d[0] = src; } template< class SrcEngine, class SrcLayout> CUTE_HOST_DEVICE void buffer_to_tensor(const uint128_t & src, Tensor & dst, int k_idx) { uint128_t* d = reinterpret_cast(&dst(0, 0, k_idx)); d[0] = src; } template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy( Tensor const& src, uint128_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 8; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; 
}; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { __builtin_amdgcn_sched_barrier(0); asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); __builtin_amdgcn_sched_barrier(0); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } else { uint32x4_t global_addr = {0}; *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx / 4; int col = lane % 4; int row_offset = row; int col_offset = col * elements_per_thread + k_idx * 32; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { __builtin_amdgcn_sched_barrier(0); asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); __builtin_amdgcn_sched_barrier(0); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, 
class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy_fp8( Tensor const& src, uint32x4_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 16; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if constexpr(use_asm) { __builtin_amdgcn_sched_barrier(0); asm volatile( "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); __builtin_amdgcn_sched_barrier(0); } else { auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy_fp8x2( Tensor const& src, uint32x2_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = 
__builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 8; if constexpr (mma_layout) { struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = lane / 8; int col = lane % 8; int row_offset = row * 4 + ((warp_id % 4)) + offset_k; int col_offset = col * elements_per_thread + (warp_id / 4 ) * 64 + k_idx * 128; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (col_offset >= 576) offset_v = -1; if constexpr(use_asm) { __builtin_amdgcn_sched_barrier(0); asm volatile( "buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n" " \n\t" :"=v"(dst), "+v"(offset_v), "+s"(global_addr) ); __builtin_amdgcn_sched_barrier(0); } else { auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, 0, offset_v, false, false); dst = *reinterpret_cast(&res); } } } template < bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false, class SrcEngine, class SrcLayout > CUTE_HOST_DEVICE void buffer_load_copy_qkvfp8( Tensor const& src, intx4_t & dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); constexpr int elements_per_thread = 16; if constexpr (mma_layout) { struct PtrWrapper { uint32_t 
former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride glob_ptr.latter |= ((row_stride) << 16); uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = !Is_even_MN ? max_MN : 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = tidx % 16; int col = lane / 16; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // uint32x2_t index_offset = {0}; // index_offset[0] = row_offset; // index_offset[1] = col_offset; if constexpr(use_asm) { // asm volatile( // "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" // " \n\t" :"=v"(dst), // "+v"(offset_v), "+s"(global_addr) // ); } else { dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false); } } else { constexpr int warp_size = 64; int tidx = threadIdx.x;//0-256 int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size;//0-63 constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576 const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride glob_ptr.latter |= ((row_stride) << 16); constexpr int elements_per_thread = 16; uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = !Is_even_MN ? 
max_MN : 0x80000000; global_addr[3] = 0x00020000; int mma_k = 32*64; int row = lane / 4; int col = lane % 4; int row_offset = row + (warp_id * 16); int col_offset = col * elements_per_thread + k_idx * 64; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false); } } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_pe( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { if constexpr (Is_load_Q) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; int mma_k = 16*256; int row = lane % 16; int col = lane / 16; int row_offset = row ; int col_offset = (col + warp_id * 4) * elements_per_thread; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; __builtin_amdgcn_sched_barrier(0); asm volatile( 
"s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); } } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { if constexpr (Is_load_Q) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; int mma_k = 16*256; int row = lane % 16; int col = lane / 16; int row_offset = row ; int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); } else { constexpr int 
warp_size = 64; int tidx = threadIdx.x;//0-256 int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size;//0-63 constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576 const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1 int mma_k = 64*64; // int row = lane / 4; // int col = lane % 4; // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // 此处待优化,后8行,行号需要交换 int virtual_row = lane / 8;//0 int virtual_col = lane % 8;//0 int swizzle_col = virtual_row ^ virtual_col; int row = lane / 4;//0 // 8->9 9->8 row = (row >= 8 ) ^ row; // row = row >= 8 ? (swizzle_col / 4) > 0 ? 
row + 1 : row - 1 : row; int col = swizzle_col % 4; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; //int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size; #if defined(__gfx938__) __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); #endif } } template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_fp8( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) { if constexpr (Is_load_Q) { } else { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 1; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[1] = 
__builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 16; constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; int mma_k = 64*64; // int row = lane / 4; // int col = lane % 4; // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // 此处待优化,后8行,行号需要交换 int virtual_row = lane / 8; int virtual_col = lane % 8; int swizzle_col = virtual_row ^ virtual_col; int row = lane / 4; // 8->9 9->8 row = (row >= 8 ) ^ row; // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row; int col = swizzle_col % 4; int row_offset = row + (warp_id * 16) ; int col_offset = col * elements_per_thread + k_idx * 64; int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; // if (thread(0)) // { // printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size); // } #if defined(__gfx936__) || defined(__gfx938__) __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); #endif } } template < bool Is_even_MN=true, bool Is_even_K=true, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout> CUTE_HOST_DEVICE void lds_direct_copy_tp1( Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, int offset_r, const int max_MN=0) { constexpr int warp_size = 64; int tidx = threadIdx.x; int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int lane = tidx % warp_size; constexpr int element_size = 2; int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); const int offset_s = 0; // global addr // uint32x4_t global_addr = {0}; // 
*(uint64_t*)&global_addr = reinterpret_cast(src.data().get()); // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride // global_addr[2] = 0xfffffffe; // global_addr[3] = 0x00020000; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr; *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()); // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride uint32x4_t global_addr = {0}; global_addr[0] = (glob_ptr.former); global_addr[1] = (glob_ptr.latter); global_addr[2] = 0x80000000; global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; constexpr int bytes_per_warp = warp_size * 8 * element_size; int mma_k = 32*64; int row = (lane / 8) * 4 + warp_id % 4 + offset_r; int col = (lane % 8); int row_offset = row ; int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64; int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1; int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id / 4) * (32 * 64 * 2) + (warp_id % 4) * 8 * 64 * 2 + k_idx * 32 * 128 * 2; ldsAddrPerWave |= (((warp_id % 4) * 2) << 16); // if (block0() && lane == 0) // { // printf(" %x \n", ldsAddrPerWave); // } __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); } template __forceinline__ __device__ void __ds_read_m32x16_row_col_lds(__fp16* lds_ptr, Tensor1& dst) { // auto lds = reinterpret_cast<__fp16 *>(src.data().get()); // auto layout = src.layout(); constexpr short offset = row * 32 * 64 * 2; auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds_ptr), offset); uint16_t * d_ptr = reinterpret_cast(&d); uint16_t * dst_ptr = 
reinterpret_cast(&(dst(0, row, col))); dst_ptr[0] = d_ptr[0]; dst_ptr[1] = d_ptr[1]; dst_ptr[2] = d_ptr[2]; dst_ptr[3] = d_ptr[3]; dst_ptr[4] = d_ptr[4]; dst_ptr[5] = d_ptr[5]; dst_ptr[6] = d_ptr[6]; dst_ptr[7] = d_ptr[7]; } template __forceinline__ __device__ auto convert_layout_acc_Aregs_tp1(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP, Tensor const& sAcc) { int tidx = threadIdx.x; int lane_id = tidx % 64; int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64); int row = (tidx % 16) + (warp_id % 4) * 16; int col = lane_id / 16; sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8] = tOrP(0, 0, 0); sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 16 * 8] = tOrP(1, 0, 0); sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 32 * 8] = tOrP(2, 0, 0); sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 48 * 8] = tOrP(3, 0, 0); // sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 0 * 8 + 64 * 32] = tOrP(0, 0, 1); // sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 16 * 8 + 64 * 32] = tOrP(1, 0, 1); // sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 32 * 8 + 64 * 32] = tOrP(2, 0, 1); // sAcc[(warp_id / 4) * 4 + col + (warp_id % 4) * 64 * 8 + (tidx % 16) * 8 + 48 * 8 + 64 * 32] = tOrP(3, 0, 1); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 8] = tOrP(1, 0, 0); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 16] = tOrP(2, 0, 0); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 24] = tOrP(3, 0, 0); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 32] = tOrP(0, 0, 1); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 8 + 64 * 32] = tOrP(1, 0, 1); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 16 + 64 * 32] = tOrP(2, 0, 1); // sAcc[(warp_id / 4) * 4 + col + row * 8 + 64 * 24 + 64 * 32] = tOrP(3, 0, 1); // 每个线程写入 4个元素, 0 256 1 257以此类推 // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 
(warp_id / 4) * 4 + 0] = tOrP(0, 0, 0); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 1] = tOrP(1, 0, 0); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 2] = tOrP(2, 0, 0); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 3] = tOrP(3, 0, 0); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 0 + 64*32] = tOrP(0, 0, 1); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 1 + 64*32] = tOrP(1, 0, 1); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 2 + 64*32] = tOrP(2, 0, 1); // sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + (warp_id / 4) * 4 + 3 + 64*32] = tOrP(3, 0, 1); // for (int n = 0; n < 2; n++) // { // for (int k = 0; k < 4; j++) // { // sAcc[] // } // } // auto thr_mma = tiled_mma.get_thread_slice(tidx); __syncthreads(); using SmemLayoutAtomP = Layout, Int<32>>, Stride, _1>>; using SmemLayoutP = decltype(tile_to_shape( SmemLayoutAtomP{}, Shape, Int<32>>{})); Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{}); auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); Tensor tSrACC = thr_mma_o.partition_fragment_A(sP_tmp); tSrACC(0, 0, 0) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 0]; tSrACC(1, 0, 0) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 1]; tSrACC(2, 0, 0) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 2]; tSrACC(3, 0, 0) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 3]; tSrACC(0, 0, 1) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 4]; tSrACC(1, 0, 1) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 5]; tSrACC(2, 0, 1) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 6]; tSrACC(3, 0, 1) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 7]; // tSrACC(0, 0, 2) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 0 + 64*32]; // tSrACC(1, 0, 2) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 1 + 64*32]; // tSrACC(2, 0, 2) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 2 + 64*32]; // tSrACC(3, 0, 2) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 3 + 64*32]; // tSrACC(0, 0, 
3) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 4 + 64*32];
    // tSrACC(1, 0, 3) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 5 + 64*32];
    // tSrACC(2, 0, 3) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 6 + 64*32];
    // tSrACC(3, 0, 3) = sAcc[lane_id * 8 + (warp_id % 4) * 64 * 8 + 7 + 64*32];
    // if (tidx < 64 && block0())
    // {
    //     printf(" %d %.2f %.2f %.2f %.2f\n ", tidx, float(tSrACC(0, 0, 1)),
    //         float(tSrACC(1, 0, 1)),
    //         float(tSrACC(2, 0, 1)),
    //         float(tSrACC(3, 0, 1))
    //     );
    // }
    return tSrACC;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Direct global->LDS copy of one 16-row x 128-column slice for the sparse-K path.
// Each lane issues a single `buffer_load_dwordx4 ... lds` (8 two-byte elements,
// 16 bytes), so the data is written to LDS without passing through VGPRs.
//
// Template parameters:
//   Is_even_MN - when false, lanes whose row index is >= `max_MN` get byte offset
//                -1 (0xFFFFFFFF), presumably so the buffer range check suppresses
//                the load (TODO confirm out-of-range behaviour on gfx936/gfx938).
//   Is_even_K  - when false, lanes whose column is >= 576 (hard-coded) are
//                disabled the same way.
//   Is_load_Q  - only the `true` path is implemented; with `false` this function
//                does nothing.
// Parameters:
//   src        - global source tensor; only its base pointer is used.
//   dst        - LDS destination tensor; only its base address is used.
//   k_idx_     - index of the 128-column slice; made wave-uniform via readfirstlane.
//   row_stride - row pitch of `src`, in elements.
//   max_MN     - number of valid rows (only read when !Is_even_MN).
//
// Unlike `lds_direct_copy`, the LDS destination does not advance with k_idx
// (`0 * mma_k`), so every call writes the same per-wave LDS window.
// On targets other than gfx936/gfx938 the body is compiled out (no-op).
template <
    bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false,
    class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_sparse_k(
    Tensor const& src,
    Tensor & dst,
    int k_idx_, const int row_stride, const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
    if constexpr (Is_load_Q) {
        // // 32x64
        constexpr int warp_size = 64;
        int tidx = threadIdx.x;
        // Wave id and slice index are made wave-uniform so values derived from
        // them can be fed to the "s" (SGPR) asm operands below.
        int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
        int lane = tidx % warp_size;
        constexpr int element_size = 2;
        int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
        const int offset_s = 0;
        // Split the 64-bit global base address into two 32-bit descriptor words.
        struct PtrWrapper { uint32_t former; uint32_t latter; };
        PtrWrapper glob_ptr;
        *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
        // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
        // Buffer resource descriptor: words 0-1 = base address (stride field left 0),
        // word 2 = 0x80000000 (presumably the num_records range limit),
        // word 3 = 0x00020000 (configuration word; TODO document the bits).
        uint32x4_t global_addr = {0};
        global_addr[0] = (glob_ptr.former);
        global_addr[1] = (glob_ptr.latter);
        global_addr[2] = 0x80000000;
        global_addr[3] = 0x00020000;
        constexpr int elements_per_thread = 8;
        constexpr int bytes_per_warp = warp_size * 8 * element_size;
        int mma_k = 16*128;
        // Lane -> (row, 8-element chunk): 16 rows x 4 chunks per wave; wave w covers
        // chunks [4w, 4w + 4) of the 128-column slice selected by k_idx.
        int row = lane % 16;
        int col = lane / 16;
        int row_offset = row ;
        int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
        int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
        if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
        if (!Is_even_K && col_offset >= 576) offset_v = -1;
        // Per-wave LDS window of bytes_per_warp (1024) bytes; always slot 0 here.
        int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + 0 * mma_k * element_size;
        // sched_barrier(0): no instruction may be scheduled across this point.
        __builtin_amdgcn_sched_barrier(0);
        // M0 <- per-wave LDS base, then a 16-byte-per-lane buffer load written to LDS.
        // NOTE(review): the asm declares no clobbers, so the compiler is not told that
        // M0 or LDS memory change; ordering relies on `volatile` + sched_barrier.
        // Confirm this is sufficient.
        asm volatile(
            "s_mov_b32 m0, %1 \n\t"
            "s_nop 0 \n\t"
            "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
            "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
            :);
        __builtin_amdgcn_sched_barrier(0);
    }
#endif
}

// Direct global->LDS copy of one tile slice per call. Each lane issues one
// `buffer_load_dwordx4 ... lds` (8 two-byte elements, 16 bytes).
//
// Is_load_Q == true:
//   Same lane mapping as lds_direct_copy_sparse_k: 16 rows x 32 columns per wave
//   inside the 128-column slice `k_idx_`. The LDS destination advances by
//   k_idx * 16*128 elements, so each slice gets its own LDS slot. When
//   !Is_even_K, columns >= 576 (hard-coded) are disabled.
// Is_load_Q == false:
//   Wave w loads rows [16w, 16w + 16) of the 32-column slice `k_idx_`.
//   Lane -> row is lane / 4, with rows 8..15 swapped pairwise (8<->9, 10<->11, ...).
//   The four 8-element chunks of a row are permuted by an XOR swizzle
//   ((lane/8) ^ (lane%8)) % 4, so each row is still fully covered.
//   The LDS destination advances by k_idx * 32*64 elements.
//   Only the row bound (max_MN) is checked; Is_even_K is not used on this path.
//
// row_stride is the row pitch of `src`, in elements. Disabled lanes use byte
// offset -1, presumably rejected by the buffer range check (TODO confirm).
// Only gfx936/gfx938 are implemented; on gfx928 and every other target the body
// is empty (silent no-op).
template <
    bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false,
    class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(
    Tensor const& src,
    Tensor & dst,
    int k_idx_, const int row_stride, const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
    {
        if constexpr (Is_load_Q) {
            // // 32x64
            constexpr int warp_size = 64;
            int tidx = threadIdx.x;
            // Wave-uniform values for the SGPR asm operands.
            int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
            int lane = tidx % warp_size;
            constexpr int element_size = 2;
            int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
            const int offset_s = 0;
            struct PtrWrapper { uint32_t former; uint32_t latter; };
            PtrWrapper glob_ptr;
            *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
            // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
            // Descriptor: base address, word 2 = 0x80000000, word 3 = 0x00020000
            // (same encoding as lds_direct_copy_sparse_k).
            uint32x4_t global_addr = {0};
            global_addr[0] = (glob_ptr.former);
            global_addr[1] = (glob_ptr.latter);
            global_addr[2] = 0x80000000;
            global_addr[3] = 0x00020000;
            constexpr int elements_per_thread = 8;
            constexpr int bytes_per_warp = warp_size * 8 * element_size;
            int mma_k = 16*128;
            // 16 rows x 4 chunks per wave; wave w covers chunks [4w, 4w + 4).
            int row = lane % 16;
            int col = lane / 16;
            int row_offset = row ;
            int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
            int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
            if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
            if (!Is_even_K && col_offset >= 576) offset_v = -1;
            // Per-wave LDS window, advanced by one 16x128 slot per k_idx.
            int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
            __builtin_amdgcn_sched_barrier(0);
            // NOTE(review): M0 and LDS side effects are not declared to the compiler.
            asm volatile(
                "s_mov_b32 m0, %1 \n\t"
                "s_nop 0 \n\t"
                "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
                "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
                :);
            __builtin_amdgcn_sched_barrier(0);
        } else {
            constexpr int warp_size = 64;
            int tidx = threadIdx.x;
            int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
            int lane = tidx % warp_size;
            constexpr int element_size = 2;
            int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
            const int offset_s = 0;
            // global addr
            // uint32x4_t global_addr = {0};
            // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get());
            // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
            // global_addr[2] = 0xfffffffe;
            // global_addr[3] = 0x00020000;
            struct PtrWrapper { uint32_t former; uint32_t latter; };
            PtrWrapper glob_ptr;
            *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
            // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
            uint32x4_t global_addr = {0};
            global_addr[0] = (glob_ptr.former);
            global_addr[1] = (glob_ptr.latter);
            global_addr[2] = 0x80000000;
            global_addr[3] = 0x00020000;
            constexpr int elements_per_thread = 8;
            constexpr int bytes_per_warp = warp_size * 8 * element_size;
            int mma_k = 32*64;
            // int row = lane / 4;
            // int col = lane % 4;
            // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // TODO: to be optimized; for the last 8 rows the row indices need to be swapped
            // XOR swizzle of the 8-element chunk index. Within each row (4 lanes) this
            // reduces to ((row / 2) ^ c) % 4, a permutation of chunks 0..3.
            int virtual_row = lane / 8;
            int virtual_col = lane % 8;
            int swizzle_col = virtual_row ^ virtual_col;
            int row = lane / 4;
            // 8->9 9->8 : rows 8..15 have their lowest bit flipped (pairwise swap).
            row = (row >= 8 ) ^ row;
            // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
            int col = swizzle_col % 4;
            int row_offset = row + (warp_id * 16) ;
            int col_offset = col * elements_per_thread + k_idx * 32;
            int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
            if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
            // Per-wave LDS window, advanced by one 32x64 slot per k_idx.
            int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
            __builtin_amdgcn_sched_barrier(0);
            asm volatile(
                "s_mov_b32 m0, %1 \n\t"
                "s_nop 0 \n\t"
                "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
                "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
                :);
            __builtin_amdgcn_sched_barrier(0);
        }
    }
#elif defined(__gfx928__)
    {
        // Not implemented for gfx928: nothing is loaded.
    }
#endif
}

// Gathered (indexed) global->LDS copy for sparse-MLA prefill.
//
// The caller supplies, per lane, the source row index (`row_offset`, where -1
// means "no row") and the 8-element chunk (`col`). The 32-column slice is
// selected by `k_idx_`. The load uses `idxen offen` addressing with the VGPR pair
// {row index, byte offset within the row}. The descriptor's stride field is set
// to row_stride * 2 bytes and its word 2 to `max_MN`.
// Rows equal to -1 are redirected to index max_MN (one past word 2), presumably
// so the buffer range check suppresses them (TODO confirm).
// The LDS destination cycles through 4 slots: (k_idx % 4) * 32*64 elements.
//
// NOTE(review): the stride is OR-ed into bits 16.. of the high address word. This
// assumes those bits are zero and that row_stride * 2 fits the stride field
// (bits 48~61 per the inline comment).
// NOTE(review): there is no gfx936/gfx938 guard here, unlike lds_direct_copy.
template
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla(
    Tensor const& src,
    Tensor & dst,
    int row_offset, int col,
    int k_idx_, const int row_stride, int max_MN=0)
{
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;
    constexpr int element_size = 2;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
    const int offset_s = 0;
    // global addr
    // uint32x4_t global_addr = {0};
    // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get());
    // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
    // global_addr[2] = 0xfffffffe;
    // global_addr[3] = 0x00020000;
    struct PtrWrapper { uint32_t former; uint32_t latter; };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
    // Row pitch in bytes goes into the descriptor's stride field (used with idxen).
    glob_ptr.latter |= ((row_stride * 2) << 16); // 62 bit: cache swizzle; 48~61: Stride
    uint32x4_t global_addr = {0};
    global_addr[0] = (glob_ptr.former);
    global_addr[1] = (glob_ptr.latter);
    global_addr[2] = max_MN;
    global_addr[3] = 0x00020000;
    constexpr int elements_per_thread = 8;
    constexpr int bytes_per_warp = warp_size * 8 * element_size;
    int mma_k = 32*64;
    // int row = lane / 4;
    // int col = lane % 4;
    // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4); // TODO: to be optimized; for the last 8 rows the row indices need to be swapped
    // int virtual_row = lane / 8;
    // int virtual_col = lane % 8;
    // int swizzle_col = virtual_row ^ virtual_col;
    // int row = lane / 4;
    // // 8->9 9->8
    // row = (row >= 8 ) ^ row;
    // // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
    // int col = swizzle_col % 4;
    // int row_offset = row + (warp_id * 16) ;
    // row_offset = gIndices[row_offset];
    int col_offset = col * elements_per_thread + k_idx * 32;
    // Only the in-row byte offset; the row is supplied through the index VGPR.
    int offset_v = (col_offset) * element_size; // bytes
    // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
    // if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1;
    int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 4) * mma_k * element_size;
    typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
    // VGPR pair for `idxen offen`: [0] = buffer index (row), [1] = byte offset.
    uint32x2_t index_offset = {0};
    index_offset[0] = row_offset == -1 ? max_MN : row_offset;
    index_offset[1] = offset_v;
    __builtin_amdgcn_sched_barrier(0);
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "s_nop 0 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
        "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
        :);
    __builtin_amdgcn_sched_barrier(0);
}

// Loads 8 FP8 bytes (buffer_load_dwordx2) for this lane into the 64-bit register
// `dst`. LDS is not involved.
//
// Only the `mma_layout == true` path is implemented; otherwise `dst` is left
// untouched.
// Builtin (default) path:
//   - descriptor stride = row_stride (bytes, since element_size is 1);
//   - buffer index      = (row_offset + 64) % 64 + (batch_stride / row_stride) * block_idx;
//   - byte offset       = col * 8 + k_idx * 32.
// A negative row_offset sets the offset to 0xFFFFFFFF, presumably out of range so
// that the lane reads zero (TODO confirm).
// NOTE(review): assumes batch_stride is a multiple of row_stride.
// NOTE(review): the use_asm path passes the {index, offset} VGPR pair with `offen`
// only (no `idxen`, unlike lds_direct_copy_for_prefill_sparse_mla) and omits the
// block_idx term, so it does not match the builtin path.
template <
    bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false,
    class SrcEngine, class SrcLayout >
CUTE_HOST_DEVICE
void
buffer_load_copy_sparse_fp8(
    Tensor const& src,
    uint64_t & dst,
    int block_idx, int batch_stride,
    int row_offset, int col,
    int k_idx_, const int row_stride, const int max_MN=0)
{
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;
    constexpr int element_size = 1;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
    constexpr int elements_per_thread = 8;
    if constexpr (mma_layout) {
        struct PtrWrapper { uint32_t former; uint32_t latter; };
        PtrWrapper glob_ptr;
        *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
        // Row pitch into the descriptor's stride field (bits 48~61 of the address pair).
        glob_ptr.latter += (row_stride << 16); // 62 bit: cache swizzle; 48~61: Stride
        uint32x4_t global_addr = {0};
        global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
        global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
        global_addr[2] = 0xFFFFFFFE;
        global_addr[3] = 0x00020000;
        // int row = tidx % 16;
        // int col = lane / 16;
        // int row_offset = row + (warp_id * 16) ;
        uint32_t col_offset = col * elements_per_thread + k_idx * 32;
        // int offset_v = (((row_offset + 64 ) % 64) * row_stride + col_offset) * element_size; // bytes
        // int offset_v = (((row_offset + 64 ) % 64) * row_stride + col_offset) * element_size + block_idx * batch_stride; // bytes
        // uint32_t offset_v = col_offset * element_size + (batch_stride) * block_idx; // bytes
        uint32_t offset_v = col_offset * element_size; // bytes
        // Invalid row: push the offset out of range (0xFFFFFFFF).
        if (row_offset < 0) offset_v = -1;
        if constexpr(use_asm) {
            typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
            uint32x2_t index_offset = {0};
            index_offset[0] = (row_offset + 64 ) % 64;
            index_offset[1] = offset_v;
            __builtin_amdgcn_sched_barrier(0);
            asm volatile(
                "buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n"
                " \n\t"
                :"=v"(dst), "+v"(index_offset), "+s"(global_addr)
                );
            __builtin_amdgcn_sched_barrier(0);
        } else {
            // auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 , offset_v, false, false);
            // Arguments: (descriptor, buffer index, byte offset, false, false).
            auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
            dst = *reinterpret_cast(&res);
        }
    }
}

// Loads 16 bytes (buffer_load_dwordx4) for this lane into `dst`.
//
// The base address is advanced by 512 + 16 before the descriptor is built.
// Presumably this is a byte offset into each cache row; confirm the units of the
// elided reinterpret_cast target.
// Builtin (default) path:
//   - descriptor stride = row_stride (bytes);
//   - buffer index      = (batch_stride / row_stride) * block_idx;
//   - byte offset       = col_offset * 2 + ((row_offset + 64) % 64) * row_stride,
//     where col_offset = col * 8 + k_idx * 32 elements.
// NOTE(review): a negative row_offset is not masked because the check is commented
// out. row_offset == -1 wraps to row 63 and loads real data.
// NOTE(review): the use_asm path adds the row to both the index and the offset,
// and uses `offen` without `idxen`; it does not match the builtin path.
// mma_layout, max_MN, Is_even_MN and Is_even_K are not used here.
template <
    bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout = false, bool use_asm = false,
    class SrcEngine, class SrcLayout >
CUTE_HOST_DEVICE
void
buffer_load_copy_sparse_decoding(
    Tensor const& src,
    uint32x4_t & dst,
    int block_idx, int batch_stride,
    int row_offset, int col,
    int k_idx_, const int row_stride, const int max_MN=0)
{
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;
    constexpr int element_size = 2;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
    constexpr int elements_per_thread = 8;
    struct PtrWrapper { uint32_t former; uint32_t latter; };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get()) + 512 + 16 ;
    // glob_ptr.latter |= ((2) << 16); // 62 bit: cache swizzle; 48~61: Stride
    glob_ptr.latter |= ((row_stride) << 16); // 62 bit: cache swizzle; 48~61: Stride
    uint32x4_t global_addr = {0};
    // global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
    // global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
    global_addr[0] = (glob_ptr.former);
    global_addr[1] = (glob_ptr.latter);
    // global_addr[2] = 0x80000000;
    global_addr[2] = 0xFFFFFFFE;
    global_addr[3] = 0x00020000;
    int mma_k = 32*64;
    // int row = tidx % 16;
    // int col = lane / 16;
    // int row_offset = row + (warp_id * 16) ;
    uint32_t col_offset = col * elements_per_thread + k_idx * 32;
    // int offset_v = ((row_offset % 64 ) * row_stride + col_offset * element_size) + 512 + 16 + block_idx * batch_stride; // bytes
    // uint32_t offset_v = (col_offset * element_size) + 512 + 16 + block_idx * batch_stride; // bytes
    uint32_t offset_v = (col_offset * element_size) + ((row_offset + 64 ) % 64 ) * row_stride; // bytes
    // uint32_t offset_v = (col_offset * element_size); // bytes
    // uint32_t offset_v = (col_offset * element_size) + ((row_offset + 64 ) % 64 ) * row_stride; // bytes
    // uint32_t offset_v = (col_offset * element_size) + 512 + 16; // bytes
    // int offset_v = (row_offset * row_stride + col_offset) * element_size + 512 + 16; // bytes
    // if (row_offset == -1) offset_v = -1;
    if constexpr(use_asm) {
        typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
        uint32x2_t index_offset = {0};
        index_offset[0] = (row_offset + 64 ) % 64;
        index_offset[1] = offset_v;
        __builtin_amdgcn_sched_barrier(0);
        asm volatile(
            "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
            " \n\t"
            :"=v"(dst), "+v"(index_offset), "+s"(global_addr)
            );
        __builtin_amdgcn_sched_barrier(0);
    } else {
        // auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
        // Block selected through the buffer index; the row is already in offset_v.
        auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (batch_stride / row_stride) * block_idx , offset_v, false, false);
        // auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64, offset_v, false, false);
        dst = *reinterpret_cast(&res);
    }
}

// /*
// for _64x32, use thread layout is 64x4, per thread get 8 elements, get 64x32 data, put data in lds with 32x64
// for _16x128, use thread layout is 16x16, per thread get 8 elements, get 16x128 data, put data in lds with 32x64
// for _16x192, use thread layout is 16x16, per thread get 12 elements, get 16x192 data, put data in lds with 48x64
// for _16x64_128, use thread layout is 16x16, per thread get 4 elements with offset 128, get 16x64 data, put data in lds with 16x64
// */
// enum MMA_LAYOUT{ _64x32 /* for gemm0 load K */, _16x128 /* for gemm1 load V */, _16x192 /* for dim 192 */, _16x64_128 /* for dim 64 */, _16x64_64 /*for load dim 64 V*/ };
// template
// CUTE_HOST_DEVICE
// void
// lds_direct_copy(
//     Tensor const& src,
//     Tensor & dst,
//     int k_idx_, const int row_stride,
//     const int max_K = 0, const int max_MN=0)
// {
//     constexpr int warp_size = 64;
//     int tidx = threadIdx.x;
//     int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
//     int lane = tidx % warp_size;
//     constexpr int element_size = 2;
//     int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
//     int k_slide = k_idx;
//     if constexpr(K_BUFF_SIZE) {
//         k_slide = (k_idx % K_BUFF_SIZE);
//     }
//     const int offset_s = 0;
//     // global addr
//     uint32x4_t global_addr = {0};
//     *(uint64_t*)&global_addr = reinterpret_cast(src.data().get());
//     global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
//     global_addr[2] = 0x80000000;
//     global_addr[3] = 
0x00020000; // if constexpr(mma_layout == _64x32) { // constexpr int elements_per_thread = 8; // constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; // int mma_k = 32*64; // int row = tidx % 16; // int col = lane / 16; // int row_offset = row + (warp_id * 16) ; // int col_offset = col * elements_per_thread + k_idx * 32; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size; // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #else // #endif // } else if constexpr(mma_layout == _16x128) { // constexpr int elements_per_thread = 8; // constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; // int mma_k = 16*128; // int row = lane / 4; // int col = tidx % 4; // int row_offset = row + k_idx * 16; // int col_offset = col * elements_per_thread + warp_id * 32; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size; // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // } else if constexpr(mma_layout == _16x192) { // constexpr int elements_per_thread = 8; // constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; // int mma_k = 48*64; // int row = lane / 4; // 
int col = tidx % 4; // int row_offset = row + k_idx * 16; // int col_offset = col * elements_per_thread + warp_id * 32; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size; // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // constexpr int elements_per_thread_tail = 4; // constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size; // row = (tidx / 8) % 16; // col = tidx % 8; // row_offset = row + k_idx * 16; // col_offset = col * elements_per_thread_tail + warp_id / 2 * 32 + /* pre offset */128 ; // offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // ldsAddrPerWave = reinterpret_cast(dst.data().get()) + /* pre offset */64*32 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size; // // if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave); // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // } else if constexpr(mma_layout == _16x64_128) { // #if 0 // constexpr int elements_per_thread = 4; // constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; // int mma_k = 16*64; // int row = (tidx / 8) % 16; // int col = tidx % 8; // int row_offset = row + k_idx * 16; // int col_offset = col * elements_per_thread + warp_id / 
2 * 32 + /* pre offset */128 ; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size; // // if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave); // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, glc lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // #else // constexpr int elements_per_thread = 8; // constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size; // int mma_k = 16*64; // int row = lane / 4 + (warp_id / 2) * 16; // int col = tidx % 4; // int row_offset = row + k_idx * 16; // int col_offset = col * elements_per_thread + (warp_id % 2) * 32 + 128; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id % 2) * bytes_per_warp + k_slide * mma_k * element_size + (warp_id/2)*mma_k * element_size ; // // if (tidx < 256) printf("tid:%d offset_v:%d row %d col %d ldsAddrPerWave:%d\n", tidx, offset_v, row_offset, col_offset, (warp_id % 2) * bytes_per_warp + k_slide * mma_k * element_size + (warp_id/2)*mma_k * element_size); // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // #endif // } else if constexpr(mma_layout == _16x64_64) { // constexpr int elements_per_thread = 4; // constexpr int bytes_per_warp = warp_size * elements_per_thread * 
element_size; // int mma_k = 16*64; // int row = (tidx / 8) % 16; // int col = tidx % 8; // int row_offset = row + k_idx * 16; // int col_offset = col * elements_per_thread + warp_id / 2 * 32; // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // if (!Is_even_K && col_offset >= max_K) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size; // // if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave); // #if defined(__gfx936__) // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, glc lds \n" ::"v"(offset_v), // "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) // :); // #endif // } // } #if 1 #define fp8 unsigned char inline __device__ float fp8e4m3_to_fp32(const fp8& input) { // const uint32_t w = (uint32_t)input << 24; // const uint32_t sign = w & UINT32_C(0x80000000); // const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); // uint32_t renorm_shift = __clz(nonsign); // renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; // const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; // uint32_t result = sign | (((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) & (~zero_mask) ); const uint32_t w = (uint32_t)input << 24; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); uint32_t renorm_shift = __clz(nonsign); renorm_shift = renorm_shift > 4 ? 
renorm_shift - 4 : 0; uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)); union { uint32_t as_bits; float as_value; } fp32 = {result}; return fp32.as_value; } __forceinline__ __device__ cutlass::bfloat16_t fp8e4m3_to_bf16(const fp8& input) { const uint16_t w = (uint16_t)input << 8; const uint16_t sign = w & UINT16_C(0x8000); const uint16_t nonsign = w & UINT16_C(0x7FFF); constexpr uint16_t exp_offset=(0x78 << 7); uint16_t result = sign | ((nonsign >> 4) + exp_offset); // if(nonsign == 0x0000) result = 0x0000; // if (thread0() && nonsign == 0x0000) // { // printf(" input = %x result = %x\n", input, result); // } return cutlass::bfloat16_t::bitcast(result); } __forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) { union uf16{ uint16_t as_bits; _Float16 as_value; } ; union uf32 { uint32_t as_bits; float as_value; }; uf16 u16; uf32 u32; u16.as_bits = (uint16_t)input << 8; u32.as_value = (float)u16.as_value; // return u32.as_bits>>16; return u32.as_value; } __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) { union uf16{ uint16_t as_bits; __fp16 as_value; } ; union uf32 { uint32_t as_bits; float as_value; }; uf16 u16; // uf32 u32; // u16.as_bits = (uint16_t)input << 8; // u32.as_value = (float)u16.as_value; // return u32.as_bits>>16; uint16_t output = (uint16_t)(input << 8); return cutlass::half_t::bitcast(output); } #else #endif template __forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B, const float& k_scale ) { typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; union { __fp16x8_t data_128; __hip_fp8x4_storage_t fp8_array[4]; } data[8]; __builtin_amdgcn_sched_barrier(0); wait_vmcnt<8>(); 
data[0].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 0)); wait_vmcnt<7>(); data[1].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 1)); wait_vmcnt<6>(); data[2].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 2)); wait_vmcnt<5>(); data[3].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 3)); wait_vmcnt<4>(); data[4].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 4)); wait_vmcnt<3>(); data[5].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 5)); wait_vmcnt<2>(); data[6].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 6)); wait_vmcnt<1>(); data[7].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 7)); __builtin_amdgcn_sched_barrier(0); #pragma unroll for (int k_idx = 0; k_idx < 8; k_idx++) { #pragma unroll for (int j = 0; j < 16; j+=4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx].fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx].fp8_array[j / 4])) + 1); if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); auto rst0 = fp8e4m3_to_bf16(f1); auto rst1 = fp8e4m3_to_bf16(f2); auto rst2 = fp8e4m3_to_bf16(f3); auto rst3 = fp8e4m3_to_bf16(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if 
constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { // auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); // auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); // auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); // auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); // tCrB(j, 0, k_idx) = f1; // tCrB(j + 1, 0, k_idx) = f2; // tCrB(j + 2, 0, k_idx) = f3; // tCrB(j + 3, 0, k_idx) = f4; __hip_fp8x4_storage_t fp8_data = data[k_idx].fp8_array[j / 4]; union Fp8_data_union{ __hip_fp8x4_storage_t fp8x4; uint16_t fp16[2]; } ; Fp8_data_union first_fp8, last_fp8; first_fp8.fp8x4 = ((fp8_data & 0xff00ff00)); last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8); tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]); tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);; tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);; tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2); auto 
rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4); cutlass::Array result0 = reinterpret_cast &>(rst0); cutlass::Array result1 = reinterpret_cast &>(rst1); tCrB(j, 0, k_idx) = result0[0]; tCrB(j + 1, 0, k_idx) = result0[1]; tCrB(j + 2, 0, k_idx) = result1[0]; tCrB(j + 3, 0, k_idx) = result1[1]; } } cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc); } } #if 1 template __forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, TiledMma tiled_mma, uint32x4_t& _data, const float& k_scale) { typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef unsigned int __hip_fp8x4_storage_t; typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned char __hip_fp8_storage_t; union { uint32x4_t data_128; __hip_fp8x4_storage_t fp8_array[4]; } data; data.data_128 = _data; for (int j = 0; j < 16; j+=4) { auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data.fp8_array[j / 4]); auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data.fp8_array[j / 4])) + 1); if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); auto rst0 = fp8e4m3_to_bf16(f1); auto rst1 = fp8e4m3_to_bf16(f2); auto rst2 = fp8e4m3_to_bf16(f3); auto rst3 = fp8e4m3_to_bf16(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); auto f4 = 
fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); if constexpr(!is_scale_equal_one) { f1 *= k_scale; f2 *= k_scale; f3 *= k_scale; f4 *= k_scale; } // if (thread0()) { // printf(" static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8) = %x f1 = %.2f\n", static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), f1); // } cutlass::NumericConverter convert_; auto rst0 = convert_(f1); auto rst1 = convert_(f2); auto rst2 = convert_(f3); auto rst3 = convert_(f4); tCrB(j, 0, k_idx) = rst0; tCrB(j + 1, 0, k_idx) = rst1; tCrB(j + 2, 0, k_idx) = rst2; tCrB(j + 3, 0, k_idx) = rst3; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) { // auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); // auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); // auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8)); // auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8)); // tCrB(j, 0, k_idx) = f1; // tCrB(j + 1, 0, k_idx) = f2; // tCrB(j + 2, 0, k_idx) = f3; // tCrB(j + 3, 0, k_idx) = f4; __hip_fp8x4_storage_t fp8_data = data.fp8_array[j / 4]; union Fp8_data_union{ __hip_fp8x4_storage_t fp8x4; uint16_t fp16[2]; } ; Fp8_data_union first_fp8, last_fp8; first_fp8.fp8x4 = ((fp8_data & 0xff00ff00)); last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8); tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]); tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);; tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);; tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;; } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) { auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8)); auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8))); auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 
8)); auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
                if constexpr(!is_scale_equal_one) {
                    f1 *= k_scale;
                    f2 *= k_scale;
                    f3 *= k_scale;
                    f4 *= k_scale;
                }
                auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2);
                auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4);
                cutlass::Array result0 = reinterpret_cast &>(rst0);
                cutlass::Array result1 = reinterpret_cast &>(rst1);
                tCrB(j, 0, k_idx) = result0[0];
                tCrB(j + 1, 0, k_idx) = result0[1];
                tCrB(j + 2, 0, k_idx) = result1[0];
                tCrB(j + 3, 0, k_idx) = result1[1];
            }
        }
        cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
#else
#endif

// gemm1_rs_fp8: dequantize an fp8 B tile from shared memory into the register fragment
// tCrB and accumulate acc += A * B over 4 k-slices with cute::gemm.
//
// For each of the 8 (i, k) fragments (i in {0,1}, k in {0..3}) one
// __builtin_amdgcn_ds_read_m32x16f16 returns 8 x 16-bit = 16 bytes per thread.
// Those bytes are reinterpreted as 4 packed words of 4 fp8 values each and converted
// into tCrB(j..j+3, i, k). KV_DTYPE and the destination element type select the
// conversion path.
//
//   acc               : accumulator fragment, updated in place by cute::gemm.
//   tCrA              : A-operand fragment, used as-is as tCrA(_, _, k_idx) (no copy here).
//   tCrB              : B-operand fragment, written element-wise as tCrB(j, i, k_idx).
//   tCsB              : shared-memory view of the fp8 B tile. Only &tCsB(0, 0, 0) and
//                       layout(0, i, k) offsets are used.
//   tiled_mma         : MMA passed to cute::gemm.
//   smem_tiled_copy_B,
//   smem_thr_copy_B   : not used by this function.
//   k_scale           : dequant scale. Applied only when is_scale_equal_one is false,
//                       and only in the branches that go through fp32.
//
// NOTE(review): the E4M3 -> bf16 (fp8e4m3_to_bf16) branch and the E5M2 bit-trick branch
// never apply k_scale, while the other two branches do. Confirm callers only use those
// paths with a unit scale.
// NOTE(review): the offsets are `constexpr`, so tCsB's layout must be static.
template __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B, const float& k_scale ) {
    typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
    typedef unsigned int __hip_fp8x4_storage_t;
    typedef unsigned short int __hip_fp8x2_storage_t;
    typedef unsigned char __hip_fp8_storage_t;

    // The fp8 bytes are read from LDS through an __fp16 view.
    auto lds = reinterpret_cast<__fp16 *>(&tCsB(0, 0, 0));
    auto layout = tCsB.layout();
    // 16 bytes per fragment: either 8 halves (as read) or 4 words of 4 fp8 bytes (as consumed).
    union {
        __fp16x8_t data_128;
        __hip_fp8x4_storage_t fp8_array[4];
    } data[8];

    // data[k * 2 + i] <- fragment at layout(0, i, k).
    // NOTE(review): offsets are layout(...) * 2, the same scaling __ds_read_m32x16_row_col
    // uses for __fp16 tiles. Confirm this matches the element size of this fp8 tile.
    constexpr short offset0 = layout(0, 0, 0) * 2;
    data[0].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset0);
    constexpr short offset1 = layout(0, 1, 0) * 2;
    data[1].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset1);
    constexpr short offset2 = layout(0, 0, 1) * 2;
    data[2].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset2);
    constexpr short offset3 = layout(0, 1, 1) * 2;
    data[3].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset3);
    constexpr short offset4 = layout(0, 0, 2) * 2;
    data[4].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset4);
    constexpr short offset5 = layout(0, 1, 2) * 2;
    data[5].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset5);
    constexpr short offset6 = layout(0, 0, 3) * 2;
    data[6].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset6);
    constexpr short offset7 = layout(0, 1, 3) * 2;
    data[7].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset7);

    #pragma unroll
    for (int k_idx = 0; k_idx < 4; k_idx++) {
        #pragma unroll
        for (int i = 0; i < 2; i++) {
            #pragma unroll
            for (int j = 0; j < 16; j+=4) {
                // Split packed word j / 4 into its two 16-bit halves: bytes 0-1 (low) and
                // bytes 2-3 (high) on this little-endian target.
                auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx * 2 + i].fp8_array[j / 4]);
                auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx * 2 + i].fp8_array[j / 4])) + 1);
                // In the branches below, `(x << 8) >> 8` does not clear the high byte,
                // because x is promoted to int first. The static_cast to the 8-bit storage
                // type is what keeps only the low byte. `x >> 8` yields the high byte.
                // f1..f4 therefore hold bytes 0..3 and are stored to tCrB(j..j+3).
                if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v) {
                    // E4M3 -> converted with fp8e4m3_to_bf16. k_scale is NOT applied here.
                    // cutlass::NumericConverter convert_;
                    auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
                    auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
                    auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
                    auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
                    auto rst0 = fp8e4m3_to_bf16(f1);
                    auto rst1 = fp8e4m3_to_bf16(f2);
                    auto rst2 = fp8e4m3_to_bf16(f3);
                    auto rst3 = fp8e4m3_to_bf16(f4);
                    tCrB(j, i, k_idx) = rst0;
                    tCrB(j + 1, i, k_idx) = rst1;
                    tCrB(j + 2, i, k_idx) = rst2;
                    tCrB(j + 3, i, k_idx) = rst3;
                } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) {
                    // E5M2 -> fp32, optional k_scale, then cutlass::NumericConverter to the
                    // destination element type.
                    auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
                    auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
                    auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
                    auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
                    if constexpr(!is_scale_equal_one) {
                        f1 *= k_scale;
                        f2 *= k_scale;
                        f3 *= k_scale;
                        f4 *= k_scale;
                    }
                    cutlass::NumericConverter convert_;
                    auto rst0 = convert_(f1);
                    auto rst1 = convert_(f2);
                    auto rst2 = convert_(f3);
                    auto rst3 = convert_(f4);
                    tCrB(j, i, k_idx) = rst0;
                    tCrB(j + 1, i, k_idx) = rst1;
                    tCrB(j + 2, i, k_idx) = rst2;
                    tCrB(j + 3, i, k_idx) = rst3;
                } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v) {
                    // E5M2 -> half without arithmetic. An E5M2 byte has the bit layout of the
                    // high byte of an IEEE half, so each byte is moved into the high byte of a
                    // 16-bit lane (low byte zero) and bitcast:
                    //   first_fp8 keeps bytes 1 and 3 in place (values j+1, j+3);
                    //   last_fp8 shifts bytes 0 and 2 up by 8 bits (values j, j+2).
                    // k_scale is NOT applied here.
                    // auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
                    // auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
                    // auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
                    // auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
                    // tCrB(j, i, k_idx) = f1;
                    // tCrB(j + 1, i, k_idx) = f2;
                    // tCrB(j + 2, i, k_idx) = f3;
                    // tCrB(j + 3, i, k_idx) = f4;
                    __hip_fp8x4_storage_t fp8_data = data[k_idx * 2 + i].fp8_array[j / 4];
                    union Fp8_data_union{
                        __hip_fp8x4_storage_t fp8x4;
                        uint16_t fp16[2];
                    } ;
                    Fp8_data_union first_fp8, last_fp8;
                    first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
                    last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
                    tCrB(j, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
                    tCrB(j + 1, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
                    tCrB(j + 2, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
                    tCrB(j + 3, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
                } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v){
                    // E4M3 -> fp32, optional k_scale, then packed into half pairs with
                    // __builtin_amdgcn_cvt_pkrtz (round toward zero). The pairs are (f1, f3) and
                    // (f2, f4); the stores below restore the f1, f2, f3, f4 order.
                    auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
                    auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
                    auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
                    auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
                    if constexpr(!is_scale_equal_one) {
                        f1 *= k_scale;
                        f2 *= k_scale;
                        f3 *= k_scale;
                        f4 *= k_scale;
                    }
                    auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f3);
                    auto rst1 = __builtin_amdgcn_cvt_pkrtz(f2, f4);
                    cutlass::Array result0 = reinterpret_cast &>(rst0);
                    cutlass::Array result1 = reinterpret_cast &>(rst1);
                    tCrB(j, i, k_idx) = result0[0];
                    tCrB(j + 1, i, k_idx) = result1[0];
                    tCrB(j + 2, i, k_idx) = result0[1];
                    tCrB(j + 3, i, k_idx) = result1[1];
                }
            }
        }
        // All 2 x 4 values for this k-slice are in tCrB; accumulate.
        cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
    }
}

// gemm_k_rs: one k-step of a GEMM whose A fragment is used as-is.
// Copies slice k_idx of B from shared memory (tCsB) into the register fragment tCrB
// through the thread's retiled copy view, then accumulates
// acc += tCrA(_, _, k_idx) * tCrB(_, _, k_idx) with cute::gemm.
// The static asserts check that the M / N / K extents of the fragments agree.
template __forceinline__ __device__ void gemm_k_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B, int k_idx) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
    // __builtin_amdgcn_sched_barrier(0);
    // __builtin_amdgcn_s_setprio(0);
    // __builtin_amdgcn_sched_barrier(0);
    // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
    // {
    //     printf(" %d %p\n", threadIdx.x, &tCsB(0, 0, k_idx));
    // }
    cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
    // if (block0())
    // {
    //     printf("thrid %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", threadIdx.x, float(tCrB_copy_view(0, 0, 0)), float(tCrB_copy_view(1, 0, 0)), float(tCrB_copy_view(2, 0, 0)),
    //     float(tCrB_copy_view(3, 0, 0)), float(tCrB_copy_view(4, 0, 0)), float(tCrB_copy_view(5, 0, 0)), float(tCrB_copy_view(6, 0, 0)), float(tCrB_copy_view(7, 0, 0))
    //     );
    // }
    // __builtin_amdgcn_sched_barrier(0);
    // __builtin_amdgcn_s_setprio(1);
    // __builtin_amdgcn_sched_barrier(0);
    cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}

// lds_direct_copy_qkvfp8_q_tp1: each lane issues one `buffer_load_dwordx4 ... lds` that
// copies 16 bytes from global `src` into shared memory `dst`. The asm is emitted only when
// __gfx938__ is defined; on other targets the function has no effect.
//
// Thread mapping (64-lane waves, 16 bytes per lane):
//   row = tidx % 16 + (warp_id % 4) * 16   -> 4 row groups of 16 rows
//   col = (lane / 16) * 16 + (warp_id / 4) * 64 + k_idx * 128
//     (4 x 16-byte chunks per row; warp_id / 4 picks the 64-column half of the 128-wide k tile)
// LDS destination for each wave:
//   dst + (warp_id % 4) * 1024 + k_idx * 64 * 128 + (warp_id / 4) * 64 * 64
//
//   k_idx_     : k-tile index. It is made wave-uniform with readfirstlane.
//   row_stride : byte distance between consecutive rows of src (element_size is 1).
//   max_MN     : row bound, used only when !Is_even_MN.
// Out-of-range rows (!Is_even_MN) and columns >= 576 (!Is_even_K) get offset -1. This
// presumably relies on the buffer range check returning zero for those lanes.
// Is_load_Q and mma_k are unused.
//
// Buffer resource (global_addr): words 0-1 hold the wave-uniform base address of src,
// word 2 is 0x80000000 (presumably the range limit), and word 3 is 0x00020000
// (hardware config bits; confirm against the ISA docs).
// NOTE(review): the "0-256" thread note implies warp_id <= 3, so warp_id / 4 is always 0
// and only the first 64 columns of each k tile would be loaded. Confirm the block size.
// NOTE(review): the asm writes m0 and LDS but declares no clobbers. sched_barrier(0) only
// constrains instruction scheduling.
template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false,
           class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_q_tp1(
    Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0 ) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;//0-256
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;//0-63
    constexpr int element_size = 1;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
    const int offset_s = 0;
    // Split the 64-bit src address into two 32-bit descriptor words.
    struct PtrWrapper {
        uint32_t former;
        uint32_t latter;
    };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
    global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;
    constexpr int elements_per_thread = 16;
    constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
    int mma_k = 64*64;
    int row = tidx % 16 + (warp_id % 4) * 16;
    int col = (lane / 16) * 16 + (warp_id / 4) * 64 + k_idx * 128;
    int offset_v = row * row_stride + (col) * element_size; // bytes
    if (!Is_even_MN && row >= max_MN) offset_v = -1;
    if (!Is_even_K && col >= 576) offset_v = -1;
    //int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
    int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id % 4) * bytes_per_warp + (k_idx ) * 64*128 * element_size + (warp_id / 4) * 64 * 64;
#if defined(__gfx938__)
    __builtin_amdgcn_sched_barrier(0);
    // m0 <- LDS destination for this wave. The load's `lds` modifier sends the data to
    // LDS instead of VGPRs. Operands: %0 = per-lane byte offset (offen),
    // %2 = buffer resource, %3 = scalar offset.
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "s_nop 0 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
        ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :);
    __builtin_amdgcn_sched_barrier(0);
#endif
}

// lds_direct_copy_qkvfp8_q_tp4: same LDS-direct load as lds_direct_copy_qkvfp8_q_tp1,
// but for a 32-row tile:
//   row = tidx % 16 + (warp_id % 2) * 16   -> 2 row groups of 16 rows
//   col = (lane / 16) * 16 + (warp_id / 2) * 64 + k_idx * 128
// LDS destination for each wave:
//   dst + (warp_id % 2) * 1024 + k_idx * 32 * 128 + (warp_id / 2) * 32 * 64
// With warp_id in 0..3, the 4 waves cover 2 row groups x 2 column halves.
// Out-of-range lanes get offset -1, as in the tp1 variant. Is_load_Q and mma_k are unused.
// NOTE(review): the asm writes m0 and LDS but declares no clobbers.
template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false,
           class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_q_tp4(
    Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0 ) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;//0-256
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;//0-63
    constexpr int element_size = 1;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
    const int offset_s = 0;
    struct PtrWrapper {
        uint32_t former;
        uint32_t latter;
    };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
    global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;
    constexpr int elements_per_thread = 16;
    constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
    int mma_k = 64*64;
    int row = tidx % 16 + (warp_id % 2) * 16;
    int col = (lane / 16) * 16 + (warp_id / 2) * 64 + k_idx * 128;
    int offset_v = row * row_stride + (col) * element_size; // bytes
    if (!Is_even_MN && row >= max_MN) offset_v = -1;
    if (!Is_even_K && col >= 576) offset_v = -1;
    //int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
    int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id % 2) * bytes_per_warp + (k_idx ) * 32*128 * element_size + (warp_id / 2) * 32 * 64;
#if defined(__gfx938__)
    __builtin_amdgcn_sched_barrier(0);
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "s_nop 0 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
        ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :);
    __builtin_amdgcn_sched_barrier(0);
#endif
}

// lds_direct_copy_qkvfp8_tp1: LDS-direct 16-byte-per-lane load, like
// lds_direct_copy_qkvfp8_q_tp1 (same 64-row LDS destination addressing), but with a
// swizzled lane -> (row, 16-byte chunk) mapping:
//   row = lane / 4, with the low bit flipped for rows >= 8 (8<->9, 10<->11, ...);
//   col = ((lane / 8) ^ (lane % 8)) % 4. The 4 lanes that share a row read the 4 chunks
//         in an order that depends on lane / 8.
// The per-wave LDS destination does not depend on the lane, so this permutation changes
// which global chunk each lane's slot receives. The resulting shared-memory layout is
// therefore swizzled; confirm the consumer expects it.
// Out-of-range lanes get offset -1, as in the q_tp1 variant. Is_load_Q and mma_k are unused.
// NOTE(review): the same warp_id / 4 vs. "0-256" thread-count question as q_tp1 applies,
// and the asm declares no clobbers.
template < bool Is_even_MN=true, bool Is_even_K=true, bool Is_load_Q=false,
           class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_tp1(
    Tensor const& src, Tensor & dst, int k_idx_, const int row_stride, const int max_MN=0) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;//0-256
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;//0-63
    constexpr int element_size = 1;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
    const int offset_s = 0;
    // global addr
    // uint32x4_t global_addr = {0};
    // *(uint64_t*)&global_addr = reinterpret_cast(src.data().get());
    // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
    // global_addr[2] = 0xfffffffe;
    // global_addr[3] = 0x00020000;
    struct PtrWrapper {
        uint32_t former;
        uint32_t latter;
    };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
    // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
    global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;
    constexpr int elements_per_thread = 16;
    constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
    int mma_k = 64*64;
    // int row = lane / 4;
    // int col = lane % 4;
    // int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
    // TODO (translated from the original note): to be optimized -- for the last 8 rows,
    // the row indices need to be swapped.
    int virtual_row = lane / 8;//0
    int virtual_col = lane % 8;//0
    int swizzle_col = virtual_row ^ virtual_col;
    int row = lane / 4;//0
    // 8->9 9->8: (row >= 8) is 0 or 1, so the XOR flips the low bit only for rows 8..15.
    row = (row >= 8 ) ^ row;
    // row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
    int col = swizzle_col % 4;
    int row_offset = row + ((warp_id % 4) * 16) ;
    int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
    int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes
    if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
    if (!Is_even_K && col_offset >= 576) offset_v = -1;
    //int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
    int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id % 4) * bytes_per_warp + (k_idx ) * 64*128 * element_size + (warp_id / 4) * 64 * 64;
#if defined(__gfx938__)
    __builtin_amdgcn_sched_barrier(0);
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "s_nop 0 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
        ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :);
    __builtin_amdgcn_sched_barrier(0);
#endif
}

// Previous, disabled implementation of convert_layout_acc_Aregs based on cute tiled
// copies through shared memory, kept for reference:
// template
// __forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP,
//                                                          Tensor const& sAcc)
// {
//     using Value_type = typename Engine0::value_type;
//     int tidx = threadIdx.x;
//     auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom{}, tiled_mma);
//     auto smem_thr_copy_ACC = smem_tiled_copy_ACC.get_thread_slice(tidx);
//     Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP);
//     Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc);
//     cute::copy(smem_tiled_copy_ACC, taccOr, taccOs);
//     // asm volatile("s_waitcnt lgkmcnt(0)\n\t");
//     __syncthreads();
//     // wangaq debug
//     // if (tidx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
//     //     int col = 8;
//     //     for (int i = 0; i < 16*64/col; ++i) {
//     //         printf("sP:%d ", i);
//     //         for (int j = 0; j < col; ++j) {
//     //             printf("%10.4f ", float(sAcc(i*col+j)));
//     //         }
//     //         printf("\n");
//     //     }
//     // }
//     auto thr_mma = tiled_mma_o.get_thread_slice(tidx);
//     auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, tiled_mma_o);
//     auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
//     Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc);
//     Tensor tSrACC = thr_mma.partition_fragment_A(sAcc);
//     Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC);
//     cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view);
//     return tSrACC;
// }

// convert_layout_acc_Aregs: re-distributes accumulator values (tOrP) into an A-operand
// fragment of tiled_mma_o by staging them through shared memory sAcc.
//   1. Each thread writes tOrP(0..3, 0, 0) to sAcc at
//      (tid % 16) * 8 + tid / 16 + c * 128 + (warp_id % 2) * 4 + (warp_id / 2) * 512,
//      for c = 0..3. With warp_id in 0..3 this covers sAcc indices 0..1023.
//   2. __syncthreads().
//   3. Each thread builds a fragment shaped by partition_fragment_A over SmemLayoutP
//      and fills it from sAcc(tid * 8 + {0..3}), +4, +512 and +516.
// `tid` is the lane index (threadIdx.x % 64). Every wave therefore reads the same sAcc
// locations and gets the same fragment, and thr_mma is sliced by lane, not threadIdx.x.
// tiled_mma and Value_type are unused.
// NOTE(review): warp_id >= 4 would write past index 1023. There is also no barrier after
// the reads, so callers must synchronize before overwriting sAcc.
template __forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP,
    Tensor const& sAcc)
{
    using Value_type = typename Engine0::value_type;
    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;
    // __fp16 *smem_ptr =
    // sAcc((tid % 16 ) * 4 + (tid / 16) + warp_id * 16 * 16) = tOrP(0, 0, 0);
    // sAcc((tid % 16 ) * 4 + (tid / 16) + 16 * 4 + warp_id * 16 * 16) = tOrP(1, 0, 0);
    // sAcc((tid % 16 ) * 4 + (tid / 16) + 2 * 16 * 4 + warp_id * 16 * 16) = tOrP(2, 0, 0);
    // sAcc((tid % 16 ) * 4 + (tid / 16) + 3 * 16 * 4 + warp_id * 16 * 16) = tOrP(3, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(0, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16) + 1 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(1, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16) + 2 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(2, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16) + 3 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(3, 0, 0);
    __syncthreads();
    // SmemLayoutP only determines the fragment shape; the values are loaded by hand below.
    using SmemLayoutAtomP = Layout, Int<64>>, Stride, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape, Int<64>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
    auto thr_mma = tiled_mma_o.get_thread_slice(tid);
    Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
    // for (int i = 0; i < 4; i++)
    // {
    //     tSrACC(i, 0, 0) = sAcc(tid * 8 + i);
    //     tSrACC(i, 0, 1) = sAcc(tid * 8 + i + 4);
    //     tSrACC(i, 0, 2) = sAcc(tid * 8 + i + 16 * 32);
    //     tSrACC(i, 0, 3) = sAcc(tid * 8 + i + 16 * 32 + 4);
    // }
    tSrACC(0, 0, 0) = sAcc(tid * 8 + 0);
    tSrACC(1, 0, 0) = sAcc(tid * 8 + 1);
    tSrACC(2, 0, 0) = sAcc(tid * 8 + 2);
    tSrACC(3, 0, 0) = sAcc(tid * 8 + 3);
    tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 4);
    tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 4);
    tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 4);
    tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 4);
    tSrACC(0, 0, 2) = sAcc(tid * 8 + 0 + 16*32);
    tSrACC(1, 0, 2) = sAcc(tid * 8 + 1 + 16*32);
    tSrACC(2, 0, 2) = sAcc(tid * 8 + 2 + 16*32);
    tSrACC(3, 0, 2) = sAcc(tid * 8 + 3 + 16*32);
    tSrACC(0, 0, 3) = sAcc(tid * 8 + 0 + 4 + 16*32);
    tSrACC(1, 0, 3) = sAcc(tid * 8 + 1 + 4 + 16*32);
    tSrACC(2, 0, 3) = sAcc(tid * 8 + 2 + 4 + 16*32);
    tSrACC(3, 0, 3) = sAcc(tid * 8 + 3 + 4 + 16*32);
    // tSrACC(i, 0, 1) = sAcc(tid * 8 + i + 4);
    // tSrACC(i, 0, 2) = sAcc(tid * 8 + i + 16 * 32);
    // tSrACC(i, 0, 3) = sAcc(tid * 8 + i + 16 * 32 + 4);
    // tSrACC(1, 0, 0) = sAcc(tid * 8);
    // for (int k = 0; k < 4; k++)
    // {
    //     tSrACC(0, 0, k) = sAcc(k * 16 * 16 + tid * 4);
    //     tSrACC(1, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 1);
    //     tSrACC(2, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 2);
    //     tSrACC(3, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 3);
    // }
    // for (int k = 0; k < 4; k++)
    // {
    //     tSrACC(0, 0, k) = sAcc(k * 16 * 16 + tid * 4);
    //     tSrACC(1, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 1);
    //     tSrACC(2, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 2);
    //     tSrACC(3, 0, k) = sAcc(k * 16 * 16 + tid * 4 + 3);
    // }
    // auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom{}, tiled_mma);
    // auto smem_thr_copy_ACC = smem_tiled_copy_ACC.get_thread_slice(tidx);
    // Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP);
    // Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc);
    // cute::copy(smem_tiled_copy_ACC, taccOr, taccOs);
    // // asm volatile("s_waitcnt lgkmcnt(0)\n\t");
    // __syncthreads();
    // wangaq debug
    // if (tidx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
    //     int col = 8;
    //     for (int i = 0; i < 16*64/col; ++i) {
    //         printf("sP:%d ", i);
    //         for (int j = 0; j < col; ++j) {
    //             printf("%10.4f ", float(sAcc(i*col+j)));
    //         }
    //         printf("\n");
    //     }
    // }
    // auto thr_mma = tiled_mma_o.get_thread_slice(tidx);
    // auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, tiled_mma_o);
    // auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
    // Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc);
    // Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
    // Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC);
    // cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view);
    return tSrACC;
}

// convert_layout_acc_Aregs_fp8 has two variants. The #if 0 one (returns an 8 x 2 element
// fragment) is disabled; the #else one (fills an intx4_t) is active.
#if 0
template __forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP,
    Tensor const& sAcc)
{
    using Value_type = typename Engine0::value_type;
    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;
    sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 0 + warp_id * 32 * 8) = tOrP(0, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 1 + warp_id * 32 * 8) = tOrP(1, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 2 + warp_id * 32 * 8) = tOrP(2, 0, 0);
    sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 3 + warp_id * 32 * 8) = tOrP(3, 0, 0);
    __syncthreads();
    using SmemLayoutAtomP = Layout, Int<64>>, Stride, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape, Int<64>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
    auto thr_mma = tiled_mma_o.get_thread_slice(tid);
    Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
    tSrACC(0, 0, 0) = sAcc(tid * 8 + 0);
    tSrACC(1, 0, 0) = sAcc(tid * 8 + 1);
    tSrACC(2, 0, 0) = sAcc(tid * 8 + 2);
    tSrACC(3, 0, 0) = sAcc(tid * 8 + 3);
    tSrACC(4, 0, 0) = sAcc(tid * 8 + 4);
    tSrACC(5, 0, 0) = sAcc(tid * 8 + 5);
    tSrACC(6, 0, 0) = sAcc(tid * 8 + 6);
    tSrACC(7, 0, 0) = sAcc(tid * 8 + 7);
    tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 16 * 32);
    tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 16 * 32);
    tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 16 * 32);
    tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 16 * 32);
    tSrACC(4, 0, 1) = sAcc(tid * 8 + 4 + 16 * 32);
    tSrACC(5, 0, 1) = sAcc(tid * 8 + 5 + 16 * 32);
    tSrACC(6, 0, 1) = sAcc(tid * 8 + 6 + 16 * 32);
    tSrACC(7, 0, 1) = sAcc(tid * 8 + 7 + 16 * 32);
    return tSrACC;
}
#else
// convert_layout_acc_Aregs_fp8 (active): stages each thread's 4 accumulator values
// tOrP(0..3, 0, 0) in sAcc at
//   (tid % 16) * 16 + ((tid / 16) % 2) * 4 + (tid / 32) * 256
//     + (warp_id % 2) * 512 + (warp_id / 2) * 8 + {0..3}
// (indices 0..1023 for warp_id 0..3), then synchronizes the block.
// Afterwards each lane loads the 16 consecutive elements starting at sAcc[tid * 16]
// into `data` as one intx4_t. This presumably assumes 1-byte elements, so that
// 16 elements = 16 bytes; confirm.
// `tid` is the lane index, so every wave reads identical data.
// tiled_mma, tiled_mma_o and Value_type are unused.
// NOTE(review): the vector load needs &sAcc[tid * 16] to be 16-byte aligned, and there
// is no barrier after the read.
template __forceinline__ __device__ void convert_layout_acc_Aregs_fp8(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor const& tOrP,
    Tensor const& sAcc, intx4_t &data)
{
    using Value_type = typename Engine0::value_type;
    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;
    sAcc[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0] = tOrP(0, 0, 0);
    sAcc[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 1] = tOrP(1, 0, 0);
    sAcc[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 2] = tOrP(2, 0, 0);
    sAcc[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 3] = tOrP(3, 0, 0);
    __syncthreads();
    data = *reinterpret_cast(&(sAcc[tid * 16]));
}
#endif

// __ds_read_m32x32_row_col_rrow has two variants. The #if 0 one copies the 16 bytes
// returned by __builtin_amdgcn_ds_read_m32x32u8 into dst(0, r_row, col); the active one
// reads straight into an intx4_t. In both, row and col must be compile-time constants
// because they feed a constexpr offset.
#if 0
template __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Tensor1& dst) {
    auto lds = reinterpret_cast(src.data().get());
    auto layout = src.layout();
    constexpr short offset = layout(0, row, col) * 1;
    auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
    uint8_t * d_ptr = reinterpret_cast(&d);
    uint8_t * dst_ptr = reinterpret_cast(&(dst(0, r_row, col)));
    dst_ptr[0] = d_ptr[0];
    dst_ptr[1] = d_ptr[1];
    dst_ptr[2] = d_ptr[2];
    dst_ptr[3] = d_ptr[3];
    dst_ptr[4] = d_ptr[4];
    dst_ptr[5] = d_ptr[5];
    dst_ptr[6] = d_ptr[6];
    dst_ptr[7] = d_ptr[7];
    dst_ptr[8] = d_ptr[8];
    dst_ptr[9] = d_ptr[9];
    dst_ptr[10] = d_ptr[10];
    dst_ptr[11] = d_ptr[11];
    dst_ptr[12] = d_ptr[12];
    dst_ptr[13] = d_ptr[13];
    dst_ptr[14] = d_ptr[14];
    dst_ptr[15] = d_ptr[15];
}
#else
// Active variant: advances the LDS pointer by layout(0, row, col) (scaled by 1,
// presumably 1-byte elements) and reads into dst with __builtin_hcu_ds_read_m32x32_i8_alt2.
// The inner #if 0 keeps the older __builtin_amdgcn_ds_read_m32x32u8 form, which passed
// the offset as an argument instead.
template __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst) {
#if 0
    auto lds = reinterpret_cast(src.data().get());
    auto layout = src.layout();
    constexpr short offset = layout(0, row, col) * 1;
    auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
    dst = d;
#else
    auto lds = reinterpret_cast(src.data().get());
    auto layout = src.layout();
    constexpr short offset = layout(0, row, col) * 1;
    lds += offset;
    dst = __builtin_hcu_ds_read_m32x32_i8_alt2((__attribute__((address_space(3))) int*)(lds));
#endif
}
#endif

/*
 * (Translated from the original note.)
 * The standard exp2f special-cases very small inputs: for an input x below -126,
 * exp2f computes 2^(x + 64) * 2^{-64}. For deep learning, numbers around 2^-126 do not
 * really matter, so it is enough to keep only v_exp_f32 and compute the result directly.
 * __llvm_exp2_f32 below binds to the llvm.exp2.f32 intrinsic via an asm label for that purpose.
 */
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");

// make_rscr: builds a 128-bit buffer resource descriptor.
//   words 0-1 : 64-bit address of ptr
//   word 2    : the `stride` argument, stored as-is
//   word 3    : 0x00010000 | (zero_pad << 8)
// NOTE(review): the other descriptor code in this file notes "48~61: Stride" (word 1),
// and buffer_load_copy_fp8_tp1 puts its row stride there. Here `stride` lands in word 2.
// Confirm which field is intended.
__device__ inline uint32x4_t make_rscr(unsigned char* ptr, const int stride, const int zero_pad)
{
    uint32x4_t rscr;
    *(uint64_t*)&rscr = (reinterpret_cast(ptr));
    rscr[2] = stride;
    rscr[3] = (1 << 16) & 0XFFFFFFFF;
    rscr[3] |= (zero_pad) << 8;
    return rscr;
}

// lds_direct_copy_qkvfp8_zero_lds: same LDS-direct buffer load and the same per-wave LDS
// destination as lds_direct_copy_qkvfp8_tp1, but every lane uses offset -1. Presumably the
// buffer range check makes each lane's load return zero, which zero-fills this wave's
// 1024-byte slice of k tile k_idx_ in dst. src only supplies the descriptor base address.
// The asm is emitted only when __gfx938__ is defined. `lane` is unused.
// NOTE(review): the asm writes m0 and LDS but declares no clobbers.
template < class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_zero_lds(
    Tensor const& src, Tensor & dst, int k_idx_) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;//0-256
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;//0-63
    constexpr int element_size = 1;
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
    const int offset_s = 0;
    struct PtrWrapper {
        uint32_t former;
        uint32_t latter;
    };
    PtrWrapper glob_ptr;
    *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
    global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;
    constexpr int elements_per_thread = 16;
    constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
    // Deliberately out of range for every lane.
    int offset_v=-1;
    int ldsAddrPerWave = reinterpret_cast(dst.data().get()) + (warp_id % 4) * bytes_per_warp + (k_idx ) * 64*128 * element_size + (warp_id / 4) * 64 * 64;
#if defined(__gfx938__)
    __builtin_amdgcn_sched_barrier(0);
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "s_nop 0 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n"
        ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) :);
    __builtin_amdgcn_sched_barrier(0);
#endif
}

// buffer_load_copy_fp8_tp1: loads 16 bytes per lane from global `src` into the register
// vector `dst` with __builtin_amdgcn_buffer_load_dwordx4, using a strided buffer
// descriptor:
//   - row_stride is ORed into the descriptor's stride field, i.e. word 1 bits 16+
//     (descriptor bits 48~61 per the note below).
//   - the lane's row (row_offset) is passed as the index and its byte column
//     (col_offset) as the offset. The argument order is presumably
//     (rsrc, vindex, offset, glc, slc), giving base + row_offset * row_stride + col_offset,
//     as in the commented-out offset_v formula.
// Lane mapping: row = lane / 4 (16 rows per wave), col = lane % 4 (4 x 16-byte chunks);
// warp_id % 4 selects the 16-row group, warp_id / 4 the 64-byte column half, and the
// template parameter k_idx the 128-byte k tile.
// When !Is_even_MN, word 2 is set to max_MN, presumably so that rows >= max_MN fail the
// range check and read as zero.
// NOTE(review): Is_even_K is not honoured; the column bound check is commented out.
// NOTE(review): row_stride must fit the 14-bit field (bits 48~61), i.e. be < 16384,
// otherwise it spills into bit 62 (cache swizzle) and above.
// element_size and mma_k are unused.
template < bool Is_even_MN=true, bool Is_even_K=true, int k_idx,
           class SrcEngine, class SrcLayout >
CUTE_HOST_DEVICE void buffer_load_copy_fp8_tp1(
    Tensor const& src, intx4_t & dst, const int row_stride, const int max_MN=0) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;
    constexpr int element_size = 1;
    constexpr int elements_per_thread = 16;
    {
        struct PtrWrapper {
            uint32_t former;
            uint32_t latter;
        };
        PtrWrapper glob_ptr;
        *(uint64_t*)&glob_ptr = reinterpret_cast(src.data().get());
        // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
        glob_ptr.latter |= ((row_stride) << 16);
        uint32x4_t global_addr = {0};
        global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
        global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
        global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
        global_addr[3] = 0x00020000;
        int mma_k = 32*64;
        int row = lane / 4;
        int col = lane % 4;
        int row_offset = row + ((warp_id % 4) * 16) ;
        int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
        // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
        // if (!Is_even_K && col_offset >=576) offset_v = -1;
        // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
        {
            dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash