Commit 79c06a56 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz buffer load提升fp8性能

parent 276a1fb7
...@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8( ...@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8(
PtrWrapper glob_ptr; PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get()); *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
uint32x4_t global_addr = {0}; uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000; global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000; global_addr[3] = 0x00020000;
int mma_k = 32*64; int mma_k = 32*64;
...@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8( ...@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8(
int col = lane / 16; int col = lane / 16;
int row_offset = row + (warp_id * 16) ; int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64; int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = col_offset;
if constexpr(use_asm) { if constexpr(use_asm) {
asm volatile( // asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" // "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), // " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) // "+v"(offset_v), "+s"(global_addr)
); // );
} }
else { else {
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
} }
} else { } else {
...@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8( ...@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8(
PtrWrapper glob_ptr; PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get()); *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
constexpr int elements_per_thread = 16; constexpr int elements_per_thread = 16;
uint32x4_t global_addr = {0}; uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000; global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000; global_addr[3] = 0x00020000;
int mma_k = 32*64; int mma_k = 32*64;
...@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8( ...@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8(
int col = lane % 4; int col = lane % 4;
int row_offset = row + (warp_id * 16); int row_offset = row + (warp_id * 16);
int col_offset = col * elements_per_thread + k_idx * 64; int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
} }
} }
...@@ -737,48 +740,48 @@ lds_direct_copy_qkvfp8( ...@@ -737,48 +740,48 @@ lds_direct_copy_qkvfp8(
if constexpr (Is_load_Q) { if constexpr (Is_load_Q) {
constexpr int warp_size = 64; constexpr int warp_size = 64;
int tidx = threadIdx.x; int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size); int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size; int lane = tidx % warp_size;
constexpr int element_size = 1; constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0; const int offset_s = 0;
struct PtrWrapper { struct PtrWrapper {
uint32_t former; uint32_t former;
uint32_t latter; uint32_t latter;
}; };
PtrWrapper glob_ptr; PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get()); *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*256;
uint32x4_t global_addr = {0}; int row = lane % 16;
global_addr[0] = (glob_ptr.former); int col = lane / 16;
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000; int row_offset = row ;
global_addr[3] = 0x00020000; int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*256;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
if (!Is_even_K && col_offset >= 576) offset_v = -1; asm volatile(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; "s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
asm volatile( "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
"s_mov_b32 m0, %1 \n\t" :);
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment