Commit 79c06a56 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz buffer load提升fp8性能

parent 276a1fb7
......@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8(
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
......@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8(
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = col_offset;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
// asm volatile(
// "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
// " \n\t" :"=v"(dst),
// "+v"(offset_v), "+s"(global_addr)
// );
}
else {
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
}
} else {
......@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8(
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
constexpr int elements_per_thread = 16;
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
......@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8(
int col = lane % 4;
int row_offset = row + (warp_id * 16);
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
}
}
......@@ -737,48 +740,48 @@ lds_direct_copy_qkvfp8(
if constexpr (Is_load_Q) {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*256;
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*256;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment