Unverified commit e99e4673, authored by rasmith, committed by GitHub
Browse files

[CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in...


[CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in lightning_attn.py (#29132)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
parent a42ab317
...@@ -198,7 +198,7 @@ def _fwd_kv_parallel( ...@@ -198,7 +198,7 @@ def _fwd_kv_parallel(
) )
# Load the decay factors for the current head and block # Load the decay factors for the current head and block
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :] k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
kv_index = tl.arange(0, CBLOCK) kv_index = tl.arange(0, CBLOCK)
...@@ -228,6 +228,12 @@ def _fwd_kv_parallel( ...@@ -228,6 +228,12 @@ def _fwd_kv_parallel(
# Load decay factor and compute weighted key-value outer product # Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr) k_decay = tl.load(k_decay_ptr)
# NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
# Please don't move it back until issue is resolved.
# Issue: https://github.com/ROCm/triton/issues/907
k_decay = k_decay[None, :]
kv += tl.dot(k_trans * k_decay, v) kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block # Move to the next sub-block
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment