Commit 902e3032 authored by zhanghj2's avatar zhanghj2
Browse files

Fix precision issue​

parent b4f69d84
...@@ -109,9 +109,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -109,9 +109,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4; // int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;; int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32; // int col_offset = col * 8 + warp_idx * 32;
if (HAVE_TOPK_LENGTH && row_offset >= topk_length) {
return params.s_kv;
}
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else { } else {
...@@ -327,10 +324,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -327,10 +324,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
accs_f32[i].w = 0.0f; accs_f32[i].w = 0.0f;
} }
auto [row_offset, col] = calc_row_and_col_k(block_idx); auto [row_offset, col] = calc_row_and_col_k(block_idx);
const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN; // const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) { // if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) {
row_offset = -1; // row_offset = -1;
} // }
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1 #if 1
...@@ -649,6 +646,15 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ ...@@ -649,6 +646,15 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1<false>(acco_f32, params.sm_scale); Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1<false>(acco_f32, params.sm_scale);
// if (block0())
// {
// printf(" threadIdx.x %d %.3f %.3f %.3f %.3f \n", threadIdx.x,
// acco_f32[0].x,
// acco_f32[0].y,
// acco_f32[0].z,
// acco_f32[0].w
// );
// }
const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v; const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o), Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{}, Shape<Int<kBlockM>, Int<kHeadDimV>>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment