Commit 3722ec71 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz fp8 tp1

parent 34489f46
This diff is collapsed.
......@@ -2748,6 +2748,59 @@ lds_direct_copy_qkvfp8_zero_lds(
#endif
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
int k_idx,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_fp8_tp1(
Tensor<SrcEngine, SrcLayout> const& src,
intx4_t & dst,
const int row_stride,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
constexpr int elements_per_thread = 16;
{
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = lane / 4;
int col = lane % 4;
int row_offset = row + ((warp_id % 4) * 16) ;
int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >=576) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
{
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment