gaoqiong / flash-attention

Commit 62025e1a, authored Nov 04, 2022 by Tri Dao

Fix more race condition in Triton bwd when there's bias

Parent: ff78ea41

Showing 1 changed file with 4 additions and 0 deletions.

flash_attn/flash_attn_triton.py (+4, -0)
@@ -16,6 +16,8 @@ Changes:
  small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
  I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- If you plan to use headdim other than 64 and 128, you should test for race conditions
  (due to the Triton compiler), as done in tests/test_flash_attn.py
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
@@ -393,6 +395,8 @@ def _bwd_kernel_one_col_block(
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (EVEN_M & EVEN_HEADDIM):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
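
For context on when the new guard matters: EVEN_M and EVEN_HEADDIM are constexpr flags that (roughly) record whether seqlen_q divides the M block size and whether headdim fills the padded BLOCK_HEADDIM, so the guarded branch runs exactly when the trailing blocks are masked; with a per-(query, key) bias (BIAS_TYPE='matrix'), that masked path could race with the later dq load, hence the tl.debug_barrier(). Below is a hedged sketch of inputs that exercise this branch, reusing the hypothetical helper sketched above; flash_attn_func and the (batch, seqlen, nheads, headdim) layout come from flash_attn_triton.py, but the exact shapes chosen here are only an assumption:

import torch
from flash_attn.flash_attn_triton import flash_attn_func  # q, k, v: (batch, seqlen, nheads, headdim)

# Deliberately not block-aligned, so EVEN_M / EVEN_HEADDIM should be False.
batch, seqlen, nheads, headdim = 2, 203, 4, 48
q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                       dtype=torch.float16, requires_grad=True) for _ in range(3)]
# A full (seqlen_q, seqlen_k) bias takes the BIAS_TYPE='matrix' path in the kernels.
bias = torch.randn(batch, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)
grad_out = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

assert_deterministic_grads(lambda: flash_attn_func(q, k, v, bias), (q, k, v), grad_out, n_repeats=30)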