Commit ae7d4f11 authored by shenzhe's avatar shenzhe
Browse files

Fix phase1 LDS address scalarization

parent 929ccc23
......@@ -144,6 +144,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
global_addr_indices[3] = 0x00020000;
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4;
ldsAddrPerWave = __builtin_amdgcn_readfirstlane(ldsAddrPerWave);
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_s = n * 1024 * 4;
const int first_index = warp_idx * 256 + lane_idx * 4;
......@@ -197,6 +198,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2;
ldsAddrPerWave = __builtin_amdgcn_readfirstlane(ldsAddrPerWave);
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
......@@ -237,6 +239,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
ldsAddrPerWave = __builtin_amdgcn_readfirstlane(ldsAddrPerWave);
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment