gaoqiong / flash-attention · Commits

Commit aacc10fb, authored Nov 02, 2022 by Tri Dao
Fix race condition in Triton bwd for non-po2 headdims
Parent: 1fb12afd

Showing 2 changed files with 23 additions and 27 deletions:
  flash_attn/flash_attn_triton.py  (+8, -16)
  tests/test_flash_attn.py         (+15, -11)
flash_attn/flash_attn_triton.py
@@ -7,7 +7,7 @@ Changes:
 - Implement cross-attention (not just self-attention).
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
 - [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
-  and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
+  and backward pass. For the backward pass, head dims that are not 64, 128 will require
   more testing since there seems to be some race conditions due to the Triton compiler.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
@@ -17,9 +17,9 @@ small batch size * nheads.
 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
 - Triton forward is generally faster than CUDA forward.
-- Triton backward is faster than CUDA backward when batch * nheads is small, and might be
-  slightly slower in other cases.
-- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
+- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is
+  slightly slower when headdim=128 and batch * nheads is large.
+- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
 import math
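For orientation, here is a minimal usage sketch of the interface this docstring describes. Only the flash_attn_func(q, k, v, causal) call pattern comes from this commit's tests; the (batch, seqlen, nheads, headdim) layout, dtype, and sizes below are illustrative assumptions.

import torch
from flash_attn.flash_attn_triton import flash_attn_func

# Assumed layout: (batch, seqlen, nheads, headdim); needs an sm80 GPU per the tests below.
batch, seqlen_q, seqlen_k, nheads, headdim = 2, 1024, 1024, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, True)   # causal=True, mirroring the test call below
out.backward(torch.randn_like(out))    # exercises the Triton backward this commit fixes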
@@ -282,7 +282,7 @@ def _bwd_kernel_one_col_block(
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
-        if not EVEN_M:
+        if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
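The EVEN_M / EVEN_HEADDIM constexpr flags indicate whether seqlen and headdim fill their blocks exactly; when they don't, loads and stores are masked, and this commit widens the tl.debug_barrier() condition so the barrier also fires for non-power-of-2 head dims. The toy kernel below only illustrates that flag/mask/barrier pattern; its names and the row-scaling operation are invented for illustration and it is not the flash-attention backward.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_rows_kernel(X, Out, scale, headdim,
                       BLOCK_HEADDIM: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_HEADDIM)
    ptrs = X + row * headdim + offs
    if EVEN_HEADDIM:
        x = tl.load(ptrs)
    else:
        # Non-power-of-2 head dims (48, 80, 88, 96, ...) need masked loads.
        x = tl.load(ptrs, mask=offs < headdim, other=0.0)
    if not EVEN_HEADDIM:
        # Ordering hint analogous to the barrier condition changed in this hunk.
        tl.debug_barrier()
    if EVEN_HEADDIM:
        tl.store(Out + row * headdim + offs, x * scale)
    else:
        tl.store(Out + row * headdim + offs, x * scale, mask=offs < headdim)

def scale_rows(x, scale=2.0):
    nrows, headdim = x.shape
    out = torch.empty_like(x)
    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
    _scale_rows_kernel[(nrows,)](x, out, scale, headdim,
                                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                                 EVEN_HEADDIM=(headdim == BLOCK_HEADDIM))
    return out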
@@ -316,6 +316,9 @@ def _bwd_kernel_one_col_block(
         if not EVEN_M:
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
+        # There's a race condition for headdim=48
+        if not EVEN_HEADDIM:
+            tl.debug_barrier()
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
         Di = tl.load(D + offs_m_curr)
@@ -390,10 +393,6 @@ def _bwd_kernel_one_col_block(
 def init_to_zero(name):
-    # def fn(nargs):
-    #     with torch.no_grad():
-    #         nargs[name].zero_()
-    #     return fn
     return lambda nargs: nargs[name].zero_()


 @triton.autotune(
@@ -406,15 +405,8 @@ def init_to_zero(name):
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
    # reset_to_zero=['DQ']
 )
 @triton.heuristics(
     {
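The init_to_zero pre_hook exists because, with SEQUENCE_PARALLEL, the backward accumulates partial dq into DQ (hence reset_to_zero/init_to_zero for 'DQ' and the non-deterministic dq noted in the tests below), so the buffer has to be zeroed before every launch the autotuner tries. Below is a self-contained sketch of that wiring; the reduction kernel is made up for illustration and only the init_to_zero/pre_hook mechanism mirrors this file.

import torch
import triton
import triton.language as tl

def init_to_zero(name):
    # The autotuner calls pre_hook with a dict of the kernel's named arguments
    # before each launch, so the named tensor can be zeroed.
    return lambda nargs: nargs[name].zero_()

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4, pre_hook=init_to_zero('Acc')),
        triton.Config({"BLOCK": 256}, num_warps=8, pre_hook=init_to_zero('Acc')),
    ],
    key=['n'],
)
@triton.jit
def _sum_kernel(X, Acc, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(X + offs, mask=offs < n, other=0.0)
    # Each program atomically accumulates into Acc[0]; Acc must start at zero,
    # which the pre_hook guarantees even while the autotuner benchmarks configs.
    tl.atomic_add(Acc, tl.sum(x, axis=0))

def total(x):
    x = x.contiguous().float()
    acc = torch.zeros(1, device=x.device, dtype=torch.float32)
    n = x.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK"]),)
    _sum_kernel[grid](x, acc, n)
    return acc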
tests/test_flash_attn.py
 import math
+from functools import partial
 import torch
 import torch.nn.functional as F
@@ -858,14 +859,14 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [40])
+# @pytest.mark.parametrize('d', [48])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
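test_flash_attn_triton checks output and gradients against a reference implementation defined elsewhere in this test file. A generic sketch of such a reference is shown below; the (batch, seqlen, nheads, headdim) layout is assumed, the causal mask follows the offs_m >= offs_n convention visible in the kernel hunk above, and the real test's reference, upcasting, and tolerances may differ.

import math
import torch

def attention_ref_sketch(q, k, v, causal=False):
    # Generic scaled-dot-product attention in fp32, shapes (batch, seqlen, nheads, headdim).
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(q.shape[-1]), k)
    if causal:
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float('-inf'))  # row i attends to columns j <= i
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum('bhts,bshd->bthd', attn, v)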
@@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [True])
-# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-@pytest.mark.parametrize('d', [64, 128])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
 def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
@@ -941,7 +942,10 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     g = torch.randn_like(output_0)
     dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
-    # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
+    # The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic
+    deterministic_dq = False
+    equal_fn = torch.equal if deterministic_dq else partial(
+        torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5)
     for i in range(10000):
         output = flash_attn_func(q, k, v, causal)
         output_equal = torch.equal(output, output_0)
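A standalone illustration of the equal_fn helper introduced here: dq keeps a bitwise comparison only when the backward is deterministic; otherwise a tolerance is used, since with SEQUENCE_PARALLEL the accumulation order of dq can vary between runs. The tensors below are stand-ins, not attention gradients.

from functools import partial
import torch

dtype = torch.float32          # stand-in; the test uses float16/bfloat16
deterministic_dq = False
equal_fn = (torch.equal if deterministic_dq
            else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))

ref = torch.randn(4, 8, dtype=dtype)
noisy = ref + 1e-6 * torch.randn_like(ref)   # jitter well below the 1e-5 atol
print(torch.equal(ref, noisy))   # typically False: bitwise comparison is too strict
print(equal_fn(ref, noisy))      # True: differences fall within the tolerance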
@@ -949,13 +953,13 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
         dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        dq_equal = torch.equal(dq, dq_0)
+        dq_equal = equal_fn(dq, dq_0)
         dk_equal = torch.equal(dk, dk_0)
         dv_equal = torch.equal(dv, dv_0)
         if not (dq_equal and dk_equal and dv_equal):
             print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
             print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
             print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
-        assert torch.equal(dq, dq_0)
+        assert equal_fn(dq, dq_0)
         assert torch.equal(dk, dk_0)
         assert torch.equal(dv, dv_0)
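To run just the Triton tests touched by this commit, an invocation along these lines should work (the -k expression is an assumption based on the test names, and an sm80 GPU is required per the skipif markers):

import pytest
# Selects test_flash_attn_triton and test_flash_attn_triton_race_condition.
pytest.main(["tests/test_flash_attn.py", "-k", "flash_attn_triton", "-x"])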