Commit 62025e1a authored by Tri Dao

Fix more race conditions in Triton bwd when there's bias

parent ff78ea41
@@ -16,6 +16,8 @@ Changes:
small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
@@ -393,6 +395,8 @@ def _bwd_kernel_one_col_block(
# compute dk = dot(ds.T, q)
dk += tl.dot(ds, q, trans_a=True)
# compute dq
if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
    tl.debug_barrier()
if not ATOMIC_ADD:
    if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
        dq = tl.load(dq_ptrs, eviction_policy="evict_last")
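
For readers following the kernel change: tl.debug_barrier() inserts a synchronization barrier across the threads of a Triton program, and in this codebase it is used as a workaround for races the Triton compiler can otherwise introduce (here, around the dq update when BIAS_TYPE='matrix'). The toy kernel below is illustrative only and not from this commit; it just shows the same placement of a barrier between a load and a store on the same buffer.

import torch
import triton
import triton.language as tl

@triton.jit
def _inplace_add_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Synchronize the program's threads before writing back, mirroring the barrier
    # added before the dq load/store in _bwd_kernel_one_col_block above.
    tl.debug_barrier()
    tl.store(x_ptr + offs, x + y, mask=mask)

def inplace_add(x: torch.Tensor, y: torch.Tensor, BLOCK: int = 1024):
    # x and y are 1-D CUDA tensors of the same length; x is updated in place.
    grid = (triton.cdiv(x.numel(), BLOCK),)
    _inplace_add_kernel[grid](x, y, x.numel(), BLOCK=BLOCK)
    return x
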