Commit 87682fe2 authored by zhaochao's avatar zhaochao
Browse files

[DCU]Skip configurations that FlashAttention does not support


Signed-off-by: default avatarzhaochao <zhaochao1@sugon.com>
parent 9d34e27a
...@@ -260,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model): ...@@ -260,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
model_configs_mla = { model_configs_mla = {
#TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
"mla_2_1": ModelConfig( # "mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 # 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
), # cross, 1 # ), # cross, 1
"mla_2_2": ModelConfig( # "mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 # 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
), # cross, 1 # ), # cross, 1
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment