"docs/source/en/conceptual/evaluation.md" did not exist on "f19f1287358beb31a71bc1bf0ef680a2c6155964"
Commit 9d3116ad authored by Tri Dao

Don't enforce bitwise consistency for dq in race condition test

Since we could be parallelizing over seqlen_k
parent 7c995381
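Why dq can differ bitwise between runs: when the backward pass is parallelized over blocks of seqlen_k, the partial dq contributions from the different blocks are accumulated in whatever order the blocks happen to finish, and floating-point addition is not associative. A standalone fp16 sketch of that effect (illustrative only, not taken from the test):

```python
import torch

# Illustration only: fp16 addition is not associative, so accumulating the same
# contributions in a different order can change the low-order bits (or more).
a = torch.tensor(1e4, dtype=torch.float16)
b = torch.tensor(-1e4, dtype=torch.float16)
c = torch.tensor(0.1, dtype=torch.float16)

left_to_right = (a + b) + c   # ~0.1: a and b cancel exactly, then c is added
right_to_left = a + (b + c)   # 0.0: b + c rounds back to -1e4 in fp16, then cancels a
print(left_to_right.item(), right_to_left.item())
print(torch.equal(left_to_right, right_to_left))  # False
```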
@@ -764,6 +764,11 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     g = torch.randn_like(output_unpad_0)
     dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                               (q_unpad, k_unpad, v_unpad), g)
+    # Parallelizing over seqlen_k makes dq non-deterministic
+    deterministic_dq = False
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
     for _ in range(10):
         torch.random.manual_seed(0)
@@ -782,7 +787,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
             dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                 (q_unpad, k_unpad, v_unpad), g)
-            assert torch.equal(dq_unpad, dq_unpad_0)
+            assert equal_fn(dq_unpad, dq_unpad_0)
             assert torch.equal(dk_unpad, dk_unpad_0)
             assert torch.equal(dv_unpad, dv_unpad_0)
...
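The tolerance trick used above can be reproduced in isolation: adding and subtracting the same constant is a mathematical no-op, but in fp16 it perturbs each element by up to its rounding error, so the maximum of that perturbation over the reference gradient gives a sensible atol for torch.allclose. A minimal sketch, assuming fp16 tensors and made-up shapes (the test's actual dq_unpad_0 comes from the attention backward pass):

```python
from functools import partial

import torch

# Sketch mirroring the comparison in the diff; shapes/dtype are illustrative stand-ins.
dq_ref = torch.randn(128, 64, dtype=torch.float16)

# Measure the rounding-error floor incurred by doing any arithmetic on the tensor.
dq_atol = ((dq_ref + 0.3 - 0.3) - dq_ref).abs().max().item()

deterministic_dq = False  # dq accumulation over seqlen_k blocks is not bitwise reproducible
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)

dq_other = dq_ref.clone()          # stand-in for dq from a repeated run
assert equal_fn(dq_other, dq_ref)  # allclose within the measured rounding tolerance
```

dk and dv are still compared with torch.equal in the test; only the dq check is relaxed, presumably because only dq is affected by the seqlen_k split.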