Commit ca2958a8 authored by zhaochao

[DCU] Fix the dimension bug in MLA under the FlashAttention backend.


Signed-off-by: zhaochao <zhaochao1@sugon.com>
parent 565fd629
......
@@ -216,6 +216,9 @@ def test_dot_product_attention(
    # FlashAttention backend
    if flash_attn_supported:
        from torch.utils.cpp_extension import IS_HIP_EXTENSION
        if IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v:
            pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
......
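For reference, a minimal sketch of the guard added above, factored into a standalone helper. IS_HIP_EXTENSION and the head_dim_qk/head_dim_v attribute names come from the hunk itself; the ModelConfig dataclass and the 192/256 head dims are illustrative assumptions, not part of the test.

from dataclasses import dataclass

from torch.utils.cpp_extension import IS_HIP_EXTENSION


@dataclass
class ModelConfig:  # hypothetical stand-in for the test's config object
    head_dim_qk: int
    head_dim_v: int


def mla_unsupported_on_rocm(config: ModelConfig) -> bool:
    # ROCm FlashAttention cannot run MLA cases where the V head dim exceeds
    # the Q/K head dim, so such test cases are skipped rather than executed.
    return IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v


# head_dim_qk=192 < head_dim_v=256 would be skipped on a ROCm build of torch.
print(mla_unsupported_on_rocm(ModelConfig(head_dim_qk=192, head_dim_v=256)))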
......
@@ -890,6 +890,13 @@ class FlashAttention(torch.nn.Module):
        elif q_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)
        if value_layer.shape[-1] != query_layer.shape[-1]:
            v_dim = value_layer.shape[-1]
            num_heads = query_layer.shape[-2]
            # Restore to (..., num_heads, head_dim_qk)
            out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1])
            output = output.view(out_shape_heads)[..., :v_dim]  # trim to V's head dim
            output = output.reshape(output.shape[:-2] + (num_heads * v_dim,))
        return output.contiguous()
......
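The added reshape-and-slice can be exercised in isolation. Below is a minimal sketch with assumed MLA-style shapes (192-dim Q/K heads, 128-dim V heads); the fix implies the fused output carries head_dim_qk slots per head whenever the Q/K and V head dims differ, so the extra columns are dropped before the result is returned.

import torch

total_tokens, num_heads = 8, 4
head_dim_qk, head_dim_v = 192, 128  # illustrative MLA-style head dims

# Stand-in for the "thd -> t(hd)" output above, still head_dim_qk wide per head.
output = torch.randn(total_tokens, num_heads * head_dim_qk)

# Restore (..., num_heads, head_dim_qk), drop the padded tail, re-flatten.
out_shape_heads = output.shape[:-1] + (num_heads, head_dim_qk)
output = output.view(out_shape_heads)[..., :head_dim_v]
output = output.reshape(output.shape[:-2] + (num_heads * head_dim_v,))
print(output.shape)  # torch.Size([8, 512]) == (total_tokens, num_heads * head_dim_v)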