Commit 98b7c697 authored by zhanghj2
Browse files

fp8 tp1性能提升

parent 24c52aee
This diff is collapsed.
...@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int ...@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
// Device-side declaration of __llvm_exp2_f32: a float -> float function whose
// assembler-level name is bound to "llvm.exp2.f32" through the __asm label.
// __attribute__((const)) tells the compiler the result depends only on the
// argument. The declaration appears twice on this line; the second copy is a
// redundant redeclaration of the same symbol.
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32"); extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");
// Packs a pointer and two integers into a four-word uint32x4_t.
//
//   word 0 : low  32 bits of `ptr`
//   word 1 : high 32 bits of `ptr`
//   word 2 : `stride`
//   word 3 : bit 16 set, OR-ed with `zero_pad` shifted left by 8
//
// The name suggests the result is used as a buffer resource descriptor;
// the hardware meaning of each field is not established in this function --
// confirm against the target ISA documentation before relying on it.
//
// ptr      : address stored in words 0..1
// stride   : value stored verbatim in word 2
// zero_pad : value placed at bit offset 8 of word 3
__device__ inline uint32x4_t make_rscr(unsigned char* ptr, const int stride, const int zero_pad) {
    const uint64_t base_addr = reinterpret_cast<uint64_t>(ptr);

    uint32x4_t descriptor;
    // Same layout a 64-bit store into the first two words produces on a
    // little-endian target: low half first, high half second.
    descriptor[0] = static_cast<uint32_t>(base_addr);
    descriptor[1] = static_cast<uint32_t>(base_addr >> 32);
    descriptor[2] = stride;
    descriptor[3] = 0x00010000u | (zero_pad << 8);
    return descriptor;
}
// Issues a single `buffer_load_dwordx4 ... lds` instruction per wave, with M0
// set to a wave-specific LDS address derived from `dst`, the wave index and
// `k_idx_`. The instruction is only emitted when `__gfx938__` is defined; on
// any other target the function computes its operands and does nothing else.
//
// Every lane passes a VGPR offset of -1 together with a descriptor whose
// word 2 is 0x80000000. Judging by the function name ("zero_lds") this is
// presumably meant to make the load fall out of range so that zeros are
// written into LDS -- NOTE(review): confirm against the gfx938 ISA.
//
// src    : only its base address (src.data().get()) is used, as words 0..1 of
//          the descriptor.
// dst    : only its base address (dst.data().get()) is used, as the LDS base.
// k_idx_ : tile index; each step advances the LDS address by
//          64*128*element_size bytes.
//
// NOTE(review): the function is marked CUTE_HOST_DEVICE but reads threadIdx.x
// and uses AMDGPU builtins, so it is presumably only ever called from device
// code -- confirm.
template <
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_zero_lds(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_)
{
// Lanes per wave assumed by all the address arithmetic below.
constexpr int warp_size = 64;
int tidx = threadIdx.x;// original note: 0-256 (presumably a 256-thread block -- TODO confirm)
// Wave index within the block; readfirstlane takes the first active lane's
// value so the result can be used as a wave-uniform (scalar) operand.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
// NOTE(review): `lane` is computed but never used in this function.
int lane = tidx % warp_size;// lane index within the wave, 0-63
// Size in bytes of one element as used by the address math.
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);// original note: 576 (meaning not evident from this function)
// Scalar offset operand (%3) of the buffer load; always zero here.
const int offset_s = 0;
// Helper used to split the 64-bit source address into two 32-bit halves
// (`former` is the first half in memory order, `latter` the second).
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Four-word descriptor passed to the buffer load as operand %2.
uint32x4_t global_addr = {0};
// Words 0..1: the two halves of the source address, made wave-uniform.
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
// Words 2..3: fixed constants. Presumably the record-count and format/flag
// words of a buffer resource descriptor -- TODO confirm for gfx938.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// A dwordx4 load moves 16 bytes per lane, i.e. 16 one-byte elements.
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;// 64*16*1 = 1024 bytes
// Per-lane VGPR offset (%0, used with `offen`); -1 is identical in every lane.
int offset_v=-1;
// LDS byte address for this wave:
//   dst base
//   + (warp_id % 4) * 1024            -- slot within a group of four waves
//   + k_idx * 64*128 * element_size   -- 8192 bytes per k_idx step
//   + (warp_id / 4) * 64*64           -- 4096 bytes per group of four waves
// NOTE(review): the size_t sum is narrowed to int; presumably safe because
// LDS addresses are small -- confirm.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + (warp_id % 4) * bytes_per_warp + (k_idx ) * 64*128 * element_size + (warp_id / 4) * 64 * 64;
#if defined(__gfx938__)
// Operands: %0 = offset_v (VGPR), %1 = ldsAddrPerWave (SGPR, copied to m0),
//           %2 = global_addr (SGPRs), %3 = offset_s (SGPR).
// NOTE(review): the asm overwrites m0 and writes LDS, but the clobber list is
// empty (no "m0", no "memory"). Confirm the compiler cannot cache m0 or move
// LDS accesses across this statement.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash } // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment