gaoqiong / flash-attention

Commit 74797571 (parent 55778193), authored Nov 06, 2022 by Tri Dao

Fix pipelining bug in Triton bwd with bias_type=matrix
Showing 2 changed files with 46 additions and 26 deletions:

flash_attn/flash_attn_triton.py    +36 -16
tests/test_flash_attn.py           +10 -10
flash_attn/flash_attn_triton.py
@@ -18,6 +18,7 @@ small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
(due to the Triton compiler), as done in tests/test_flash_attn.py
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
@@ -250,6 +251,29 @@ def _bwd_preprocess_do_o_dot(

```python
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    ...
```
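
For intuition, the boundary masking that `_bwd_store_dk_dv` performs can be mimicked with a small plain-PyTorch sketch (assumed tile sizes, not part of this commit): rows past `seqlen_k` and columns past `headdim` are simply never written, which is what the `mask=` arguments to `tl.store` guarantee.

```python
# Plain-PyTorch analogue of the masked write-back. The tile shape and the
# "ragged" extents are made-up assumptions; the kernel does this with tl.store.
import torch

BLOCK_N, BLOCK_HEADDIM = 64, 128      # assumed tile shape
seqlen_k, headdim = 50, 96            # assumed extents inside the tile

dk_tile = torch.randn(BLOCK_N, BLOCK_HEADDIM)    # values the kernel computed
dk_out = torch.zeros(BLOCK_N, BLOCK_HEADDIM)     # destination buffer

offs_n = torch.arange(BLOCK_N)[:, None]          # row offsets within the tile
offs_d = torch.arange(BLOCK_HEADDIM)[None, :]    # column offsets within the tile
mask = (offs_n < seqlen_k) & (offs_d < headdim)  # same predicate as the EVEN_*=False branch

dk_out[mask] = dk_tile[mask]                     # masked write-back, like tl.store(dk_ptrs, dk, mask=mask)
```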
@@ -287,6 +311,16 @@ def _bwd_kernel_one_col_block(

```python
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
```
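
To see why the early exit above is needed, here is a rough sketch of how the query-block loop can end up with zero iterations for the failing shape named in the comment. The block sizes and the causal `begin_m` formula are assumptions for illustration, not taken from this diff:

```python
# Sketch (assumptions, not verbatim kernel code): with causal masking, the m-loop
# for a given key/value block starts at the first query block that can attend to it.
# For seqlen_q=113, seqlen_k=255 the second key block starts past seqlen_q, so the
# loop has zero iterations; the commit sidesteps pipelining of that empty loop by
# storing zero dK/dV and returning early.
BLOCK_M, BLOCK_N = 128, 128          # assumed block sizes
seqlen_q, seqlen_k = 113, 255        # the failing case named in the comment

num_n_blocks = (seqlen_k + BLOCK_N - 1) // BLOCK_N
for start_n in range(num_n_blocks):
    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M   # causal case (assumed formula)
    steps = len(range(begin_m, seqlen_q, BLOCK_M))
    print(f"key block {start_n}: begin_m={begin_m}, iterations={steps}")
# key block 1 gives begin_m=128 >= seqlen_q=113, i.e. a zero-trip loop.
```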
@@ -437,22 +471,8 @@ def _bwd_kernel_one_col_block(

In the write-back at the end of the kernel, the inline masked stores are removed and replaced by a call to the new `_bwd_store_dk_dv` helper:

```diff
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
-    # if we just call tl.store(dv_ptrs), there's a race condition
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv)
-            tl.store(dk_ptrs, dk)
-        else:
-            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
-            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
-    else:
-        if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
-            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
-        else:
-            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
+                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)


 def init_to_zero(name):
```
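
For context on the `bias_type=matrix` case named in the commit message: a full per-(query, key) bias matrix is passed straight to `flash_attn_func`, as the tests below do. A hedged usage sketch follows; the (batch, seqlen, nheads, headdim) tensor layout and the concrete sizes are assumptions for illustration, while the positional call order matches the test code:

```python
# Hedged sketch: exercising the Triton kernel with a full bias matrix
# (the bias_type='matrix' path this commit touches). Requires a CUDA GPU.
import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, seqlen_q, seqlen_k, nheads, d = 2, 113, 255, 4, 64   # assumed sizes
device, dtype = 'cuda', torch.float16
causal = True

q = torch.randn(batch, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
# A (batch, 1, seqlen_q, seqlen_k) bias, broadcast over heads: a 'matrix' bias.
bias = torch.randn(batch, 1, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, bias, causal)   # causal flag passed positionally, as in the tests
out.sum().backward()                           # backward through the Triton kernel
```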
tests/test_flash_attn.py
@@ -864,14 +864,13 @@ from flash_attn.flash_attn_triton import flash_attn_func

```python
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [48])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256),
                                               (256, 512), (512, 256), (1024, 1024), (1023, 1024),
                                               (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['1h1k']))
# @pytest.mark.parametrize('bias_shape', (['1hqk']))
def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
```
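
The `bias_shape` codes in the parametrization above read like broadcast patterns over a (batch, nheads, seqlen_q, seqlen_k) bias, with `1` marking a broadcast dimension. A hedged sketch of that reading (the actual bias construction lives elsewhere in the test file and is not shown in this diff):

```python
# Assumed decoding of the bias_shape codes; '1h1k' and 'b11k' broadcast over
# seqlen_q (a per-key "vector" bias), while '1hqk' and 'b1qk' carry a full matrix.
import torch

def make_bias(bias_shape, batch_size, nheads, seqlen_q, seqlen_k):
    shapes = {
        '1h1k': (1, nheads, 1, seqlen_k),
        '1hqk': (1, nheads, seqlen_q, seqlen_k),
        'b11k': (batch_size, 1, 1, seqlen_k),
        'b1qk': (batch_size, 1, seqlen_q, seqlen_k),
    }
    return None if bias_shape is None else torch.randn(shapes[bias_shape])
```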
@@ -935,13 +934,13 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape)

```python
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [96])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256),
                                               (256, 512), (512, 256), (1024, 1024), (1023, 1024),
                                               (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['b1qk']))
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
```
@@ -979,6 +978,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape)

```python
        output = flash_attn_func(q, k, v, bias, causal)
        output_equal = torch.equal(output, output_0)
        if not output_equal:
            # Printing / computing diff sometimes makes the race condition disappear
            print(f'{dtype=}, {causal=}, {d=}, {seqlen_q=}, {seqlen_k=}, {bias_shape=}, {i=}')
            print(f'Output max diff: {(output - output_0).abs().max().item()}')
        assert torch.equal(output, output_0)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
```
@@ -986,7 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape)

```python
        dk_equal = torch.equal(dk, dk_0)
        dv_equal = torch.equal(dv, dv_0)
        if not (dq_equal and dk_equal and dv_equal):
            print(f'{i=}')
            print(f'{dtype=}, {causal=}, {d=}, {seqlen_q=}, {seqlen_k=}, {bias_shape=}, {i=}')
            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
```