Commit e83a4119 authored by zhanghj2
Browse files

tail guard to opt sparse prefill

parent 9a805181
......@@ -98,6 +98,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// Translate the slot currently held in row_offset into the index stored for
// it. Exactly one branch is compiled, selected by the kernel's template flags.
if constexpr (IS_TOPK_2048) {
    // IS_TOPK_2048: sIndices is addressed modulo 1024
    // (presumably a 1024-entry window of the index list; confirm against
    // the code that fills sIndices).
    row_offset = sIndices[row_offset % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
    // CACHE_INDICES_IN_LDS: sIndices is addressed directly, no wrap-around.
    // The first condition must test IS_TOPK_2048 alone; if it also tested
    // CACHE_INDICES_IN_LDS this branch could never be selected.
    row_offset = sIndices[row_offset];
} else {
    // Neither flag set: read the index from gIndices.
    row_offset = gIndices[row_offset];
}
......@@ -109,9 +111,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
if constexpr (IS_TOPK_2048) {
row_offset = sIndices[row_offset % 1024];
} else {
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
}
else {
row_offset = gIndices[row_offset];
}
row_offset = row_offset == -1 ? params.s_kv : row_offset;
......@@ -178,20 +183,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) {
constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
int col_offset = col;
int offset_v = col_offset * 2;
......@@ -217,20 +208,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) {
constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
int col_offset = col;
// int v_idx = row_offset;
......@@ -324,6 +301,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
accs_f32[i].w = 0.0f;
}
// Row offset and column returned by calc_row_and_col_k for this block.
auto [row_offset, col] = calc_row_and_col_k(block_idx);
// Slot of this thread's row: row_ plus 16 per warp plus kBlockN per block.
const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
// Tail guard: only when IS_TOPK_2048 is false and HAVE_TOPK_LENGTH is true,
// and only on the last block (block_idx == num_topk_blocks - 1). A slot at or
// past topk_length is forced to the -1 sentinel.
if (!IS_TOPK_2048 && HAVE_TOPK_LENGTH && block_idx == num_topk_blocks - 1 && row_in_topk >= topk_length) {
row_offset = -1;
}
// The -1 sentinel (from the lookup or from the guard above) is replaced with
// params.s_kv.
row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1
#define LOAD_K_AND_QK_GEMM(k) \
......@@ -389,8 +370,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN;
int t;
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
if constexpr (IS_TOPK_2048) {
t = sIndices[offs % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} else {
t = gIndices[offs];
}
......@@ -629,6 +612,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
}
else
{
// Run the first-block pass only when there is at least one block to process.
if (num_topk_blocks > 0) {
    process_one_block(0, IsFirstBlock{});
}
for (int block_idx = 1; block_idx < num_topk_blocks; block_idx ++)
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment