gaoqiong / flash-attention · Commits

Commit 470010f5
Authored Nov 03, 2022 by Tri Dao
Parent: aacc10fb

Fix race condition for Triton bwd for headdim 48 and 96
Showing 1 changed file with 19 additions and 22 deletions

flash_attn/flash_attn_triton.py  (+19 / -22)
...
...
@@ -4,21 +4,26 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Implement both causal and non-causal attention.
-- Implement cross-attention (not just self-attention).
+- Implement both self-attention and cross-attention.
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
-  and backward pass. For the backward pass, head dims that are not 64, 128 will require
-  more testing since there seems to be some race conditions due to the Triton compiler.
+- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
   small batch size * nheads.
+
+Caution:
+- If you plan to use headdim other than 64 and 128, you should test for race conditions
+  (due to the Triton compiler), as done in tests/test_flash_attn.py
+  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
+  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+  that there are none left for other head dimensions.
+
 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
 - Triton forward is generally faster than CUDA forward.
 - Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
   It is slightly slower when headdim=128 and batch * nheads is large.
 - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
...
...
@@ -276,6 +281,7 @@ def _bwd_kernel_one_col_block(
                                           & (offs_d[None, :] < headdim), other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
+        # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
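As an aside on the masking just above: out-of-range key columns have to be pushed to -inf before the softmax, because zeroing the scores would still let the padded columns contribute to the softmax denominator. A small PyTorch sketch of the same idea, with illustrative names that are not part of the kernel:

# Reference sketch (PyTorch) of the tl.where masking above: key columns at or beyond
# seqlen_k are set to -inf so they contribute exp(-inf) = 0 to the softmax normalizer.
import torch

def masked_softmax_scores(q_block, k_block, seqlen_k):
    # q_block: (BLOCK_M, d); k_block: (BLOCK_N, d), where BLOCK_N may overrun seqlen_k
    qk = q_block @ k_block.t()                       # same product as tl.dot(q, k, trans_b=True)
    offs_n = torch.arange(k_block.shape[0])
    qk = qk.masked_fill(offs_n[None, :] >= seqlen_k, float("-inf"))
    return torch.softmax(qk, dim=-1)                 # padded columns get exactly zero weight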
...
...
@@ -313,7 +319,7 @@ def _bwd_kernel_one_col_block(
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
-        if not EVEN_M:
+        if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # There's a race condition for headdim=48
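The EVEN_M, EVEN_N, and EVEN_HEADDIM flags that gate these branches are compile-time specialization booleans. Their exact derivation lives in the launch heuristics elsewhere in this file; the sketch below only spells out the meaning implied by the masks in the kernel (offs_m_curr < seqlen_q, offs_n < seqlen_k, offs_d < headdim) and should not be read as the file's actual heuristics code.

# Sketch of what the specialization flags used above stand for (inferred from the masks
# in the kernel, not copied from this file's launch heuristics).
def specialization_flags(seqlen_q, seqlen_k, headdim, BLOCK_M, BLOCK_N, BLOCK_HEADDIM):
    EVEN_M = seqlen_q % BLOCK_M == 0          # no partial query block: row masks can be dropped
    EVEN_N = seqlen_k % BLOCK_N == 0          # no partial key block: column masks can be dropped
    EVEN_HEADDIM = headdim == BLOCK_HEADDIM   # head dim fills its block: no offs_d < headdim mask
    return EVEN_M, EVEN_N, EVEN_HEADDIM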
...
...
@@ -329,16 +335,10 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            if EVEN_M:
-                if EVEN_HEADDIM:
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                 dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                 dq += tl.dot(ds, k)
                 tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-                else:
-                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
-                                 eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim,
-                             eviction_policy="evict_last")
             else:
                 if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
...
...
@@ -356,11 +356,8 @@ def _bwd_kernel_one_col_block(
                              eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            if EVEN_M:
-                if EVEN_HEADDIM:
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                 tl.atomic_add(dq_ptrs, dq)
-                else:
-                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
             else:
                 if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
...
...
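For orientation on the two hunks above: each column block of keys contributes ds @ k for that block to the same rows of dq, so when the backward is parallelized across seqlen_k (the ATOMIC_ADD path), several programs write to the same dq locations and the updates must go through tl.atomic_add. A hedged PyTorch sketch of the accumulation the kernel performs, with illustrative names:

# Reference sketch (PyTorch) of the dq accumulation in the kernel: dq is the sum over all
# key-column blocks of ds_block @ k_block. When those blocks are handled by different
# programs in parallel, the partial sums target the same dq rows, hence tl.atomic_add.
import torch

def dq_from_key_blocks(ds, k, block_n):
    # ds: (seqlen_q, seqlen_k) gradient w.r.t. the attention scores; k: (seqlen_k, d)
    seqlen_k, d = k.shape
    dq = torch.zeros(ds.shape[0], d, dtype=ds.dtype)
    for start_n in range(0, seqlen_k, block_n):        # one iteration ~ one column block
        ds_block = ds[:, start_n:start_n + block_n]
        k_block = k[start_n:start_n + block_n]
        dq += ds_block @ k_block                       # dq += tl.dot(ds, k)
    return dq                                          # equals ds @ k computed in one shot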