OpenDAS / TransformerEngine · Commits

Commit 87682fe2, authored Nov 03, 2025 by zhaochao
[DCU]Skip configurations that FlashAttention does not support
Signed-off-by: zhaochao <zhaochao1@sugon.com>
parent 9d34e27a
Showing 1 changed file with 15 additions and 14 deletions.

tests/pytorch/attention/test_attention.py (+15, -14) · View file @ 87682fe2
...
@@ -260,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 model_configs_mla = {
+    #TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
     # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
-    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
-    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
-    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
-    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
-    "mla_2_1": ModelConfig(
-        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
-    ),  # cross, 1
-    "mla_2_2": ModelConfig(
-        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
-    ),  # cross, 1
-    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
-    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
-    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
-    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
+    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
+    # "mla_2_1": ModelConfig(
+    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
+    # ),  # cross, 1
+    # "mla_2_2": ModelConfig(
+    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
+    # ),  # cross, 1
+    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
+    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
     "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
 }
...
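For context, the commit comments out every MLA configuration whose head_dim_qk differs from head_dim_v, since (per the TODO) FlashAttention on ROCm only supports MLA when the two match; only mla_3_4 (dqk = 160, head_dim_v = 160) is left active. The same constraint could be expressed as a runtime skip instead of deleting entries. The sketch below is one way to do that with pytest; the ModelConfig attribute names (head_dim_qk, head_dim_v), the simplified test signature, and the backend probe are assumptions for illustration, not part of this commit.

import pytest
import torch

def rocm_flash_attention_mla_limited() -> bool:
    # Assumed probe: torch.version.hip is a version string on ROCm builds of
    # PyTorch and None elsewhere, so this flags the builds the TODO refers to.
    return torch.version.hip is not None

@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(model):
    config = model_configs_mla[model]
    # Skip instead of commenting out: unsupported configurations stay in the
    # dictionary and show up as SKIPPED in reports rather than vanishing
    # from collection entirely.
    if rocm_flash_attention_mla_limited() and config.head_dim_qk != config.head_dim_v:
        pytest.skip("ROCm FlashAttention only supports MLA with head_dim_qk == head_dim_v")
    # ... the original test body would run here ...

With a guard like this, the twelve disabled mla_* entries above could remain in model_configs_mla, and the suite would still pass on DCU/ROCm while continuing to exercise them on backends that do support mismatched head dimensions.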