Commit 0ce8ee82 authored by zhanghj2's avatar zhanghj2
Browse files

fix 关闭attn sink情况下的错误

parent 200f01d5
...@@ -86,7 +86,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -86,7 +86,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : logf(sum_lse) + max_lse; float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : logf(sum_lse) + max_lse;
if (lane_idx == 0) if (lane_idx == 0)
gLse(warp_idx) = global_lse; gLse(warp_idx) = global_lse;
float o_scale = 0.0f; float o_scale = 1.0f;
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
float attn_sink = __ldg(params.attn_sink + q_head_idx); float attn_sink = __ldg(params.attn_sink + q_head_idx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment