Unverified Commit 8b681d77 authored by Tianxing Wu's avatar Tianxing Wu Committed by GitHub
Browse files

[Rocm] Fix to the rocm_mla_decode_rope.py returning random result (#3898)

parent 194eea17
...@@ -230,7 +230,7 @@ def _fwd_grouped_kernel_stage1_rope( ...@@ -230,7 +230,7 @@ def _fwd_grouped_kernel_stage1_rope(
other=0.0, other=0.0,
) # positional embedding part of keys ) # positional embedding part of keys
if USE_ROPE and start_n >= cur_batch_seq_len - BLOCK_N: if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N:
k_pe = tl.where( k_pe = tl.where(
offs_n[None, :] != (split_kv_end - 1), offs_n[None, :] != (split_kv_end - 1),
k_pe, k_pe,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment