Commit aa20d10a authored by zsolt-borbely-htec, committed by GitHub

[Misc] [ROCm] Prevent surplus tensor reshape (#19803)


Signed-off-by: Zsolt Borbely <zsolt.borbely@htecgroup.com>
parent 2de12be4
@@ -376,7 +376,7 @@ class TritonAttentionImpl(AttentionImpl):
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
-query = query.reshape((num_tokens, num_heads, head_size))
+query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
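Note on the change: the removed and added lines have identical text, so the edit is whitespace-only. The sketch below illustrates one plausible reading of the commit title, assuming the reshape back to `(num_tokens, num_heads, head_size)` now runs only in the branch that flattened the query. The enclosing condition is not visible in this hunk, so the `quantize` flag, `prepare_query`, and `fake_quant` are hypothetical names, and nested lists stand in for tensors so the example runs without PyTorch.

```python
# Sketch only: nested lists stand in for tensors; all names here are hypothetical.

def fake_quant(rows, scale):
    """Hypothetical stand-in for a quantization op that expects 2-D input."""
    return [[x / scale for x in row] for row in rows]


def prepare_query(query, scale, quantize):
    """Return the query laid out as (num_tokens, num_heads, head_size)."""
    num_heads = len(query[0])
    head_size = len(query[0][0])
    if quantize:
        # Flatten each token to num_heads * head_size values for the 2-D op.
        flat = [[x for head in token for x in head] for token in query]
        flat = fake_quant(flat, scale)
        # Restore the 3-D layout here, inside the branch that flattened it.
        query = [
            [row[h * head_size:(h + 1) * head_size] for h in range(num_heads)]
            for row in flat
        ]
    # When quantize is False the query was never flattened, so no reshape runs.
    return query


# 2 tokens, 2 heads, head_size 2
q = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
untouched = prepare_query(q, scale=2.0, quantize=False)
quantized = prepare_query(q, scale=2.0, quantize=True)
```

With `quantize=False` the function returns the input object unchanged, which is the case where an unconditional reshape would have been surplus work.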