Commit 9d3116ad authored by Tri Dao

Don't enforce bitwise consistency for dq in race condition test

Since we could be parallelizing over seqlen_k
parent 7c995381
@@ -764,6 +764,11 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     g = torch.randn_like(output_unpad_0)
     dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                               (q_unpad, k_unpad, v_unpad), g)
+    # Parallelizing over seqlen_k makes dq non-deterministic
+    deterministic_dq = False
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
     for _ in range(10):
         torch.random.manual_seed(0)
@@ -782,7 +787,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
             dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                 (q_unpad, k_unpad, v_unpad), g)
-            assert torch.equal(dq_unpad, dq_unpad_0)
+            assert equal_fn(dq_unpad, dq_unpad_0)
             assert torch.equal(dk_unpad, dk_unpad_0)
             assert torch.equal(dv_unpad, dv_unpad_0)
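For context, the change above amounts to the tolerance-based comparison sketched below. This is a minimal, self-contained illustration, not code from the repository: the tensor shape and the 1e-8 perturbation are hypothetical stand-ins for the dq gradients of two runs, chosen only to show why torch.allclose with an atol derived from one rounding step passes where torch.equal fails once dq is accumulated in a run-dependent order (e.g. when parallelizing over seqlen_k).

```python
from functools import partial

import torch

# Hypothetical stand-ins for the test's tensors (not real FlashAttention gradients).
dq_unpad_0 = torch.randn(512, 64)                             # dq from the reference run
dq_unpad = dq_unpad_0 + 1e-8 * torch.randn_like(dq_unpad_0)   # dq from a rerun, off by rounding-level noise

# Tolerance = the rounding error of a single add/subtract at dq's scale,
# computed the same way as dq_atol in the hunk above.
dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()

# When dq's accumulation order can vary between runs, fall back from bitwise
# equality to allclose with that tolerance.
deterministic_dq = False
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)

assert equal_fn(dq_unpad, dq_unpad_0)   # passes; torch.equal would generally fail here
# dk and dv remain bitwise-checked with torch.equal in the test itself.
```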