OpenDAS / TransformerEngine · Commits
Commit ca2958a8 authored Oct 23, 2025 by zhaochao
[DCU] Fix the dimension bug in MLA under the FlashAttention backend.
Signed-off-by: zhaochao <zhaochao1@sugon.com>
Parent: 565fd629
Showing 2 changed files with 10 additions and 0 deletions.
tests/pytorch/attention/test_attention.py (+3 −0)
transformer_engine/pytorch/attention/dot_product_attention/backends.py (+7 −0)
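As context for the diffs below: in MLA-style attention the Q/K head dim and the V head dim can differ, and the backends.py hunk further down implies that on this FlashAttention path the thd-format output is flattened with the Q/K head dim per head, so its last dimension becomes num_heads * head_dim_qk instead of the expected num_heads * head_dim_v. A minimal sketch of that mismatch (not part of the commit; every shape here is made up for illustration):

    import torch

    # Hypothetical MLA-style shapes, chosen only to illustrate the mismatch.
    t, num_heads = 32, 16                # total tokens (thd layout), attention heads
    head_dim_qk, head_dim_v = 192, 128   # Q/K head dim differs from V head dim

    # Output after the thd -> t(hd) flatten when the kernel keeps head_dim_qk per head:
    output = torch.randn(t, num_heads * head_dim_qk)
    expected_last_dim = num_heads * head_dim_v

    print(output.shape[-1], expected_last_dim)   # 3072 vs 2048 -> shape error downstream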
tests/pytorch/attention/test_attention.py @ ca2958a8

@@ -216,6 +216,9 @@ def test_dot_product_attention(
     # FlashAttention backend
     if flash_attn_supported:
+        from torch.utils.cpp_extension import IS_HIP_EXTENSION
+        if IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v:
+            pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
transformer_engine/pytorch/attention/dot_product_attention/backends.py @ ca2958a8

@@ -890,6 +890,13 @@ class FlashAttention(torch.nn.Module):
         elif q_format == "thd":
             # thd -> t(hd)
             output = output.reshape(output.shape[0], -1)
+            if value_layer.shape[-1] != query_layer.shape[-1]:
+                v_dim = value_layer.shape[-1]
+                num_heads = query_layer.shape[-2]
+                # restore to (..., num_heads, head_dim_qk)
+                out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1])
+                output = output.view(out_shape_heads)[..., :v_dim]  # truncate to V's dimension
+                output = output.reshape(output.shape[:-2] + (num_heads * v_dim,))
         return output.contiguous()
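To make the added branch easier to follow, here is a self-contained sketch of the same reshape logic with hypothetical tensors (none of these names or sizes come from the repository): it recovers the per-head layout using the Q/K head dim, slices each head down to the V head dim, and flattens back to t(hd).

    import torch

    # Standalone sketch of the added backends.py logic; all shapes are hypothetical.
    t, num_heads = 32, 16
    head_dim_qk, head_dim_v = 192, 128

    query_layer = torch.randn(t, num_heads, head_dim_qk)
    value_layer = torch.randn(t, num_heads, head_dim_v)
    # thd -> t(hd) output as the FlashAttention path leaves it (head_dim_qk per head)
    output = torch.randn(t, num_heads * head_dim_qk)

    if value_layer.shape[-1] != query_layer.shape[-1]:
        v_dim = value_layer.shape[-1]
        n_heads = query_layer.shape[-2]
        # restore to (..., n_heads, head_dim_qk)
        out_shape_heads = output.shape[:-1] + (n_heads, query_layer.shape[-1])
        output = output.view(out_shape_heads)[..., :v_dim]  # keep only V's head dim
        # flatten back to t(hd) with the corrected per-head dim
        output = output.reshape(output.shape[:-2] + (n_heads * v_dim,))

    print(output.shape)  # torch.Size([32, 2048]) == (t, num_heads * head_dim_v)

With this truncation the thd output ends up with num_heads * head_dim_v features, which appears consistent with the new test-side skip: the slice only makes sense when head_dim_qk >= head_dim_v, and the head_dim_qk < head_dim_v case is now skipped on ROCm.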