Commit 902e3032 authored by zhanghj2
Browse files

Fix precision issue

parent b4f69d84
......@@ -109,9 +109,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32;
if (HAVE_TOPK_LENGTH && row_offset >= topk_length) {
return params.s_kv;
}
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024];
} else {
......@@ -327,10 +324,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
accs_f32[i].w = 0.0f;
}
auto [row_offset, col] = calc_row_and_col_k(block_idx);
const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) {
row_offset = -1;
}
// const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN;
// if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) {
// row_offset = -1;
// }
row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1
......@@ -649,6 +646,15 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1<false>(acco_f32, params.sm_scale);
// if (block0())
// {
// printf(" threadIdx.x %d %.3f %.3f %.3f %.3f \n", threadIdx.x,
// acco_f32[0].x,
// acco_f32[0].y,
// acco_f32[0].z,
// acco_f32[0].w
// );
// }
const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment