Unverified Commit 5bb3a412 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Add the missing 1HSS tests (#1052)



Add the missing 1HSS tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent d74e65f5
...@@ -295,7 +295,10 @@ class FusedAttnRunner: ...@@ -295,7 +295,10 @@ class FusedAttnRunner:
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: if (
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
):
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip( pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for " "B1SS, BHSS and 11SS bias shapes are only supported for "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment