Unverified Commit d710c241 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Skip t3hd/th3d for MQA/GQA tests (#1293)



skip some t3hd/th3d tests for MQA/GQA
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 933294dc
......@@ -644,6 +644,9 @@ model_configs_layout_thd = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment