Commit b4f69d84 authored by zhanghj2
Browse files

Optimize sparse prefill for h_q = 128

parent e83a4119
...@@ -98,8 +98,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -98,8 +98,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} else { } else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
} }
...@@ -111,12 +109,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -111,12 +109,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4; // int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;; int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32; // int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048) { if (HAVE_TOPK_LENGTH && row_offset >= topk_length) {
row_offset = sIndices[row_offset % 1024]; return params.s_kv;
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} }
else { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024];
} else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
} }
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
...@@ -183,6 +181,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -183,6 +181,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) { auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) {
constexpr int element_size = 2; constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8; constexpr int elements_per_thread = 8;
int col_offset = col; int col_offset = col;
int offset_v = col_offset * 2; int offset_v = col_offset * 2;
...@@ -208,6 +220,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -208,6 +220,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) { auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) {
constexpr int element_size = 2; constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8; constexpr int elements_per_thread = 8;
int col_offset = col; int col_offset = col;
// int v_idx = row_offset; // int v_idx = row_offset;
...@@ -302,9 +328,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -302,9 +328,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
} }
auto [row_offset, col] = calc_row_and_col_k(block_idx); auto [row_offset, col] = calc_row_and_col_k(block_idx);
const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN; const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
if (!IS_TOPK_2048 && HAVE_TOPK_LENGTH && block_idx == num_topk_blocks - 1 && row_in_topk >= topk_length) { if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) {
row_offset = -1; row_offset = -1;
} }
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1 #if 1
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
...@@ -370,10 +397,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -370,10 +397,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16; const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN; int offs = n_idx + block_idx * kBlockN;
int t; int t;
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
t = sIndices[offs % 1024]; t = sIndices[offs % 1024];
} else if constexpr (CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset];
} else { } else {
t = gIndices[offs]; t = gIndices[offs];
} }
...@@ -1272,7 +1297,7 @@ static void run_h64_fast_path(const SparseAttnFwdParams& params) { ...@@ -1272,7 +1297,7 @@ static void run_h64_fast_path(const SparseAttnFwdParams& params) {
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
if (params.h_q == 64) { if (params.h_q == 64 || params.h_q == 128) {
if (params.attn_sink) { if (params.attn_sink) {
run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, true>(params); run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, true>(params);
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment