gaoqiong / flash-attention · Commits · 731f154d
Commit 731f154d, authored Nov 01, 2022 by Tri Dao
Fix race conditions in the Triton bwd for headdim=64
Parent: 9b0bc978
Showing 1 changed file with 20 additions and 12 deletions

flash_attn/flash_attn_triton.py (+20, -12)
...
@@ -42,7 +42,6 @@ import triton.language as tl
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
-        # "EVEN_N": lambda args: False,
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
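These heuristics flag whether each problem dimension divides its tile size evenly, so the kernels can take unmasked (faster) loads and stores when that is safe; the deleted line was a commented-out debugging override. A minimal sketch of the pattern, assuming a hypothetical copy kernel (copy_kernel, seqlen, and BLOCK are illustrative names, not from this file):

import triton
import triton.language as tl

@triton.heuristics({"EVEN_M": lambda args: args["seqlen"] % args["BLOCK"] == 0})
@triton.jit
def copy_kernel(X, Y, seqlen, BLOCK: tl.constexpr, EVEN_M: tl.constexpr):
    # EVEN_M is computed per launch by the heuristic above
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    if EVEN_M:
        x = tl.load(X + offs)  # every block is full; no bounds mask needed
    else:
        x = tl.load(X + offs, mask=offs < seqlen, other=0.0)  # partial tail block
    tl.store(Y + offs, x, mask=offs < seqlen)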
...
@@ -86,8 +85,8 @@ def _fwd_kernel(
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
-    # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
-    # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
+    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
+    # tl.load(q_ptrs), we get the wrong output!
     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
...
@@ -238,7 +237,7 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    # initialize dv amd dk
+    # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
...
@@ -282,7 +281,8 @@ def _bwd_kernel_one_col_block(
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
-        if not EVEN_HEADDIM:
+        # Also wrong for headdim=64.
+        if not EVEN_M:
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
...
@@ -293,21 +293,26 @@ def _bwd_kernel_one_col_block(
         # the output is correct.
         if EVEN_M & EVEN_HEADDIM:
             do = tl.load(do_ptrs)
+        else:
+            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
+            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+                                       & (offs_d[None, :] < headdim), other=0.0)
         # if EVEN_M:
         #     if EVEN_HEADDIM:
         #         do = tl.load(do_ptrs)
         #     else:
         #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-        else:
-            if EVEN_HEADDIM:
-                do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
-            else:
-                do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
-                                           & (offs_d[None, :] < headdim), other=0.0)
+        # else:
+        #     if EVEN_HEADDIM:
+        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+        #     else:
+        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+        #                                    & (offs_d[None, :] < headdim), other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
         if not EVEN_M:
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
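The fix above loads do under the row mask and the head-dim mask combined; per the in-line note, masking rows alone (m_mask without d_mask) triggered a race. A self-contained sketch of that 2-D masked-load pattern under assumed names (load_do_block and its parameters are illustrative, not the repo's kernel):

import triton
import triton.language as tl

@triton.jit
def load_do_block(DO, OUT, seqlen_q, headdim, stride_row,
                  BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    ptrs = offs_m[:, None] * stride_row + offs_d[None, :]
    # combine the row (m) mask with the head-dim (d) mask, as the fix does
    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
    do = tl.load(DO + ptrs, mask=mask, other=0.0)
    tl.store(OUT + ptrs, do, mask=mask)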
...
@@ -366,7 +371,9 @@ def _bwd_kernel_one_col_block(
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-    if EVEN_N:
+    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.store(dv_ptrs), there's a race condition
+    if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             tl.store(dv_ptrs, dv)
             tl.store(dk_ptrs, dk)
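The write-back now takes the unmasked tl.store fast path only when both EVEN_N and EVEN_M hold, mirroring the load-side workaround. A hedged sketch of that guarded-store shape (the kernel, its flag plumbing, and all names here are illustrative assumptions, not the repo's code):

import triton
import triton.language as tl

@triton.jit
def writeback_dv(SRC, DV, seqlen_k, headdim, stride_n,
                 BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr,
                 EVEN_N: tl.constexpr, EVEN_M: tl.constexpr,
                 EVEN_HEADDIM: tl.constexpr):
    offs_n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    ptrs = offs_n[:, None] * stride_n + offs_d[None, :]
    mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
    dv = tl.load(SRC + ptrs, mask=mask, other=0.0)
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(DV + ptrs, dv)  # fast path: the whole tile is in bounds
        else:
            tl.store(DV + ptrs, dv, mask=offs_d[None, :] < headdim)
    else:
        tl.store(DV + ptrs, dv, mask=mask)  # safe masked fallback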
...
@@ -536,6 +543,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
     delta = torch.empty_like(lse)
...
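The single added line in the last hunk fills in the conventional attention scaling 1/sqrt(d) when the caller passes no scale. For reference, an equivalent in plain Python (resolve_softmax_scale is an illustrative name):

import math

def resolve_softmax_scale(softmax_scale, d):
    # `x or y` yields y when x is None (or 0), matching the diff's
    # softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    return softmax_scale or 1.0 / math.sqrt(d)

print(resolve_softmax_scale(None, 64))  # 0.125 == 1/sqrt(64)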