Commit 72b2aea0 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz tp8

parent 6d3ed1da
...@@ -863,7 +863,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -863,7 +863,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
constexpr static int STAGE = 8; constexpr static int STAGE = 8;
#if 1 #if 1
uint8_t* kv_lds_write_ptr_base = reinterpret_cast<uint8_t*>(&tSsK(0, 0 ,0)); extern __shared__ char shared_memory[];
int row_ = lane_idx / 8;
int col_ = lane_idx % 8;
int swizzle_col_ = row_ ^ col_;
uint8_t* kv_lds_write_ptr_base = reinterpret_cast<uint8_t*>(shared_memory) +
row_ * 128 + swizzle_col_ * 16 + warp_idx * 16 * 64;
v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1; v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1;
c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f; c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f;
...@@ -877,7 +882,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -877,7 +882,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f; c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f;
c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f; c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f;
extern __shared__ char shared_memory[];
struct IsMaskBlock {}; struct IsMaskBlock {};
struct IsFirstMaskBlock {}; struct IsFirstMaskBlock {};
struct IsNoMaskBlock {}; struct IsNoMaskBlock {};
......
...@@ -642,8 +642,8 @@ buffer_load_copy_qkvfp8( ...@@ -642,8 +642,8 @@ buffer_load_copy_qkvfp8(
global_addr[3] = 0x00020000; global_addr[3] = 0x00020000;
int mma_k = 32*64; int mma_k = 32*64;
int row = lane % 16; int row = lane / 4;
int col = lane / 16; int col = lane % 4;
int row_offset = row + (warp_id * 16); int row_offset = row + (warp_id * 16);
int col_offset = col * elements_per_thread + k_idx * 64; int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment