Commit 40f4bf39 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz fp8 tp1

parent 5e577dee
......@@ -1555,6 +1555,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
buffer_load_copy_fp8_tp1<true, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
if (warp_id < 4)
buffer_load_copy_fp8_tp1<true, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
}
......@@ -1598,6 +1599,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data;
kv_lds_write_ptr += 64 * 128;
if (warp_id < 4)
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data;
}
// asm volatile("s_barrier \n\t");
......
......@@ -2809,11 +2809,12 @@ buffer_load_copy_fp8_tp1(
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr.latter |= ((row_stride) << 16);
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[2] = !Is_even_MN ? max_MN : 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
......@@ -2821,12 +2822,12 @@ buffer_load_copy_fp8_tp1(
int col = lane % 4;
int row_offset = row + ((warp_id % 4) * 16) ;
int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >=576) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >=576) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
{
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, row_offset, col_offset, false, false);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment