gaoqiong / flash-attention · Commits

Commit 9b0bc978: Fix race condition in Triton fwd
Authored Oct 31, 2022 by Tri Dao
Parent: 215930bc
Showing 2 changed files with 56 additions and 5 deletions:

  flash_attn/flash_attn_triton.py   +15  -5
  tests/test_flash_attn.py          +41  -0
flash_attn/flash_attn_triton.py @ 9b0bc978
"""
"""
Based on
the FlashAttention implementation from Phil Tillet.
We use
the FlashAttention implementation from Phil Tillet
a starting point
.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
Changes:
...
@@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
...
@@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
- Make the backward for d=128 much faster by reducing register spilling.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
small batch size * nheads.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""
"""
import
math
import
math
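As the updated docstring notes, this Triton path has no dropout and no ragged-batch support. A minimal usage sketch, assuming flash_attn_func is the entry point exported by flash_attn/flash_attn_triton.py (assumed import path) and using the same (batch, seqlen, nheads, headdim) layout and positional causal argument as the new test below:

import torch
from flash_attn.flash_attn_triton import flash_attn_func  # assumed import path

batch_size, seqlen, nheads, d = 2, 1024, 4, 64
q, k, v = [torch.randn(batch_size, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]
out = flash_attn_func(q, k, v, True)  # causal=True; no dropout argument in the Triton version
out.backward(torch.randn_like(out))   # exercises the Triton backward pass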
...
@@ -26,7 +33,8 @@ import triton.language as tl
 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+        # This config has a race condition when EVEN_M == False, disabling it for now.
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
     ],
     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
 )
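For context on why simply dropping a config is a safe mitigation: triton.autotune benchmarks every listed triton.Config once per distinct value of the key tuple and caches the fastest one, so commenting out the 64x64 config only shrinks the search space. A toy stand-in for that selection logic (illustration only; pick_config and time_kernel are hypothetical names, not Triton APIs):

# Hypothetical sketch of key-based config selection, mirroring the configs/key above.
configs = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8, "num_stages": 1},
    # The 64x64 candidate is omitted, matching the commit: it raced when EVEN_M == False.
]
_best_config_cache = {}

def pick_config(key, time_kernel):
    # Benchmark each remaining candidate once per key tuple, then reuse the cached winner.
    if key not in _best_config_cache:
        _best_config_cache[key] = min(configs, key=time_kernel)
    return _best_config_cache[key]

# Example: one cache entry per (CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL, BLOCK_HEADDIM).
best = pick_config((8, 16, True, 64), time_kernel=lambda cfg: 1.0)  # dummy timing hook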
...
@@ -34,6 +42,7 @@ import triton.language as tl
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+        # "EVEN_N": lambda args: False,
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
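To make the EVEN_* flags concrete, the same heuristics can be evaluated in plain Python for a few of the (seqlen_q, seqlen_k) pairs exercised by the new test; the headdim and BLOCK_HEADDIM values below are illustrative assumptions:

# Plain-Python rendering of the @triton.heuristics lambdas above (illustration only).
heuristics = {
    "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
    "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
    "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}

for seqlen_q, seqlen_k in [(113, 203), (108, 256), (1024, 1024)]:  # shapes from the new test
    args = {"seqlen_q": seqlen_q, "seqlen_k": seqlen_k, "BLOCK_M": 128, "BLOCK_N": 128,
            "headdim": 48, "BLOCK_HEADDIM": 64}  # head dimensions assumed for illustration
    print((seqlen_q, seqlen_k), {name: fn(args) for name, fn in heuristics.items()})
# With BLOCK_M = BLOCK_N = 128: (113, 203) -> EVEN_M and EVEN_N both False,
# (108, 256) -> EVEN_M False, EVEN_N True, (1024, 1024) -> both True.
# The uneven cases are exactly where the guarded loads below and the disabled 64x64 config matter.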
...
@@ -95,7 +104,7 @@ def _fwd_kernel(
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 k = tl.load(k_ptrs + start_n * stride_kn)
             else:
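The else branch is cut off by the hunk, but presumably it falls back to a masked load that zero-fills out-of-bounds rows (the usual tl.load(..., mask=..., other=0.0) pattern). A rough PyTorch illustration of that boundary handling, with load_k_tile as a hypothetical helper:

import torch

def load_k_tile(k, start_n, BLOCK_N):
    # Toy stand-in (plain PyTorch, not Triton) for loading a (BLOCK_N, d) tile of K.
    # In the EVEN_N case the slice is always full; otherwise the tail tile is zero-padded,
    # which is what a masked load would produce inside the kernel.
    seqlen_k, d = k.shape
    tile = k[start_n:start_n + BLOCK_N]
    if tile.shape[0] < BLOCK_N:  # out-of-bounds rows become zeros
        pad = torch.zeros(BLOCK_N - tile.shape[0], d, dtype=k.dtype, device=k.device)
        tile = torch.cat([tile, pad], dim=0)
    return tile

k = torch.randn(203, 64)                        # seqlen_k=203 is one of the shapes in the new test
assert load_k_tile(k, 0, 64).shape == (64, 64)  # fully in-bounds tile
tail = load_k_tile(k, 192, 64)                  # tail tile: only rows 0..10 are real data
assert tail.shape == (64, 64) and torch.equal(tail[11:], torch.zeros(53, 64))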
...
@@ -129,7 +138,7 @@ def _fwd_kernel(
         acc_o_scale = tl.load(t_ptrs)
         acc_o = acc_o * acc_o_scale[:, None]
         # update acc_o
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
             else:
...
@@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
-        tl.debug_barrier()
+        if not EVEN_M:
+            tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
...
tests/test_flash_attn.py @ 9b0bc978
...
@@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
+        pytest.skip()  # Reference implementation OOM
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
+    output_0 = flash_attn_func(q, k, v, causal)
+    g = torch.randn_like(output_0)
+    dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
+    # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
+    for i in range(10000):
+        output = flash_attn_func(q, k, v, causal)
+        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
+        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(output, output_0)
+        # assert torch.equal(dq, dq_0)
+        # assert torch.equal(dk, dk_0)
+        # assert torch.equal(dv, dv_0)
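The new test only asserts that repeated forward calls are bit-identical; the gradient checks stay commented out, consistent with the remaining backward nondeterminism noted in the kernel comments (ATOMIC_ADD / SEQUENCE_PARALLEL). A hedged sketch of what re-enabling them would look like, reusing the names defined in the test above and assuming the sequence-parallel (atomic-add) backward path is disabled so the backward is deterministic:

# Sketch only: extends the loop body of the test above to also check backward determinism.
for i in range(10):
    output = flash_attn_func(q, k, v, causal)
    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
    assert torch.equal(output, output_0)
    assert torch.equal(dq, dq_0)
    assert torch.equal(dk, dk_0)
    assert torch.equal(dv, dv_0)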