Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
ff78ea41
Commit
ff78ea41
authored
Nov 04, 2022
by
Tri Dao
Browse files
Fix race condition in Triton bwd when there's bias
parent
86862cfd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
1 deletion
+3
-1
flash_attn/flash_attn_triton.py
flash_attn/flash_attn_triton.py
+1
-0
tests/test_flash_attn.py
tests/test_flash_attn.py
+2
-1
No files found.
flash_attn/flash_attn_triton.py
View file @
ff78ea41
...
@@ -326,6 +326,7 @@ def _bwd_kernel_one_col_block(
...
@@ -326,6 +326,7 @@ def _bwd_kernel_one_col_block(
if
IS_CAUSAL
:
if
IS_CAUSAL
:
qk
=
tl
.
where
(
offs_m_curr
[:,
None
]
>=
(
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
qk
=
tl
.
where
(
offs_m_curr
[:,
None
]
>=
(
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
if
BIAS_TYPE
!=
'none'
:
if
BIAS_TYPE
!=
'none'
:
tl
.
debug_barrier
()
# Race condition otherwise
if
BIAS_TYPE
==
'vector'
:
if
BIAS_TYPE
==
'vector'
:
if
EVEN_N
:
if
EVEN_N
:
bias
=
tl
.
load
(
b_ptrs
).
to
(
tl
.
float32
)
bias
=
tl
.
load
(
b_ptrs
).
to
(
tl
.
float32
)
...
...
tests/test_flash_attn.py
View file @
ff78ea41
...
@@ -976,7 +976,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
...
@@ -976,7 +976,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
equal_fn
=
torch
.
equal
if
deterministic_dq
else
partial
(
torch
.
allclose
,
atol
=
dq_atol
)
equal_fn
=
torch
.
equal
if
deterministic_dq
else
partial
(
torch
.
allclose
,
atol
=
dq_atol
)
# Run 10000 times and check that the results don't change
# Run 10000 times and check that the results don't change
for
i
in
range
(
10000
):
for
i
in
range
(
10000
):
output
=
flash_attn_func
(
q
,
k
,
v
,
None
,
causal
)
output
=
flash_attn_func
(
q
,
k
,
v
,
bias
,
causal
)
output_equal
=
torch
.
equal
(
output
,
output_0
)
output_equal
=
torch
.
equal
(
output
,
output_0
)
if
not
output_equal
:
# Printing / computing diff sometimes makes the race condition disappear
if
not
output_equal
:
# Printing / computing diff sometimes makes the race condition disappear
print
(
f
'Output max diff:
{
(
output
-
output_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'Output max diff:
{
(
output
-
output_0
).
abs
().
max
().
item
()
}
'
)
...
@@ -986,6 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
...
@@ -986,6 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
dk_equal
=
torch
.
equal
(
dk
,
dk_0
)
dk_equal
=
torch
.
equal
(
dk
,
dk_0
)
dv_equal
=
torch
.
equal
(
dv
,
dv_0
)
dv_equal
=
torch
.
equal
(
dv
,
dv_0
)
if
not
(
dq_equal
and
dk_equal
and
dv_equal
):
if
not
(
dq_equal
and
dk_equal
and
dv_equal
):
print
(
f
'
{
i
=
}
'
)
print
(
f
'dQ max diff:
{
(
dq
-
dq_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'dQ max diff:
{
(
dq
-
dq_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'dK max diff:
{
(
dk
-
dk_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'dK max diff:
{
(
dk
-
dk_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'dV max diff:
{
(
dv
-
dv_0
).
abs
().
max
().
item
()
}
'
)
print
(
f
'dV max diff:
{
(
dv
-
dv_0
).
abs
().
max
().
item
()
}
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment