Commit e83a4119 authored by zhanghj2's avatar zhanghj2
Browse files

tail guard to opt sparse prefill

parent 9a805181
...@@ -98,6 +98,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -98,6 +98,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} else { } else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
} }
...@@ -109,9 +111,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -109,9 +111,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4; // int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;; int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32; // int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { if constexpr (IS_TOPK_2048) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else { } else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
}
else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
} }
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
...@@ -178,20 +183,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -178,20 +183,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) { auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) {
constexpr int element_size = 2; constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8; constexpr int elements_per_thread = 8;
int col_offset = col; int col_offset = col;
int offset_v = col_offset * 2; int offset_v = col_offset * 2;
...@@ -217,20 +208,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -217,20 +208,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) { auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) {
constexpr int element_size = 2; constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8; constexpr int elements_per_thread = 8;
int col_offset = col; int col_offset = col;
// int v_idx = row_offset; // int v_idx = row_offset;
...@@ -324,6 +301,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -324,6 +301,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
accs_f32[i].w = 0.0f; accs_f32[i].w = 0.0f;
} }
auto [row_offset, col] = calc_row_and_col_k(block_idx); auto [row_offset, col] = calc_row_and_col_k(block_idx);
const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
if (!IS_TOPK_2048 && HAVE_TOPK_LENGTH && block_idx == num_topk_blocks - 1 && row_in_topk >= topk_length) {
row_offset = -1;
}
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1 #if 1
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
...@@ -389,8 +370,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -389,8 +370,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16; const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN; int offs = n_idx + block_idx * kBlockN;
int t; int t;
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { if constexpr (IS_TOPK_2048) {
t = sIndices[offs % 1024]; t = sIndices[offs % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} else { } else {
t = gIndices[offs]; t = gIndices[offs];
} }
...@@ -629,6 +612,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -629,6 +612,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
} }
else else
{ {
if (num_topk_blocks > 0)
process_one_block(0, IsFirstBlock{}); process_one_block(0, IsFirstBlock{});
for (int block_idx = 1; block_idx < num_topk_blocks; block_idx ++) for (int block_idx = 1; block_idx < num_topk_blocks; block_idx ++)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment