Commit d0787acc authored by Tri Dao's avatar Tri Dao
Browse files

Relax dropout_fraction test

parent dca6d89d
......@@ -1430,7 +1430,7 @@ def test_flash_attn_varlen_output(
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment