Commit 40f4bf39 authored by zhanghj2
Browse files

优化nmz fp8 tp1

parent 5e577dee
...@@ -1555,6 +1555,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1555,6 +1555,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
buffer_load_copy_fp8_tp1<true, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1<true, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1<true, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1<true, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
if (warp_id < 4)
buffer_load_copy_fp8_tp1<true, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - block_idx * kBlockN); buffer_load_copy_fp8_tp1<true, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
} }
...@@ -1598,6 +1599,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1598,6 +1599,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
kv_lds_write_ptr += 64 * 128; kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data; *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data;
kv_lds_write_ptr += 64 * 128; kv_lds_write_ptr += 64 * 128;
if (warp_id < 4)
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data; *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data;
} }
// asm volatile("s_barrier \n\t"); // asm volatile("s_barrier \n\t");
......
...@@ -2809,11 +2809,12 @@ buffer_load_copy_fp8_tp1( ...@@ -2809,11 +2809,12 @@ buffer_load_copy_fp8_tp1(
PtrWrapper glob_ptr; PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get()); *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
uint32x4_t global_addr = {0}; uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former); global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter); global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000; global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000; global_addr[3] = 0x00020000;
int mma_k = 32*64; int mma_k = 32*64;
...@@ -2821,12 +2822,12 @@ buffer_load_copy_fp8_tp1( ...@@ -2821,12 +2822,12 @@ buffer_load_copy_fp8_tp1(
int col = lane % 4; int col = lane % 4;
int row_offset = row + ((warp_id % 4) * 16) ; int row_offset = row + ((warp_id % 4) * 16) ;
int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64; int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes // int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >=576) offset_v = -1; // if (!Is_even_K && col_offset >=576) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; // if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
{ {
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment