Commit 9a805181 authored by zhanghj2's avatar zhanghj2
Browse files

Clean up redundant code​

parent bd0083f5
......@@ -384,63 +384,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
}
#undef LOAD_K_AND_QK_GEMM
#else
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val); \
buffer_load_lds_k(row_offset, col, k_val + 1); \
buffer_load_lds_k(row_offset, col, k_val + 2); \
buffer_load_lds_k(row_offset, col, k_val + 3); \
buffer_load_lds_k(row_offset, col, k_val + 4); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 1>(q_reg[k_val + 1].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 2>(q_reg[k_val + 2].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 3>(q_reg[k_val + 3].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 4>(q_reg[k_val + 4].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_barrier \n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
LOAD_K_AND_QK_GEMM(0);
LOAD_K_AND_QK_GEMM(5);
LOAD_K_AND_QK_GEMM(10);
{
constexpr int k_val = (15);
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val + 1);
buffer_load_lds_k(row_offset, col, k_val + 2);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val + 1>(q_reg[k_val + 1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val + 2>(q_reg[k_val + 2].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
}
#endif
auto is_valid_token = [&](const int idx) -> bool {
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
......@@ -697,15 +640,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1<false>(acco_f32, params.sm_scale);
// if (block0())
// {
// printf(" threadIdx.x %d %.3f %.3f %.3f %.3f \n", threadIdx.x,
// acco_f32[0].x,
// acco_f32[0].y,
// acco_f32[0].z,
// acco_f32[0].w
// );
// }
const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment