gaoqiong / flash-attention · Commits

Commit 1fb12afd, authored Nov 01, 2022 by Tri Dao (parent 731f154d)

Avoid memcpy in the Triton bwd
Showing 2 changed files with 37 additions and 23 deletions:

flash_attn/flash_attn_triton.py  +23 -15
tests/test_flash_attn.py         +14 -8

flash_attn/flash_attn_triton.py
@@ -559,7 +559,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
-    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
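For context on the launch style above (and the removed TODO about the autotuner), here is a minimal, self-contained sketch of launching a Triton kernel with a hard-coded meta-parameter instead of @triton.autotune. The kernel and names below are illustrative only and are not part of flash-attention.

import torch
import triton
import triton.language as tl

# Illustrative kernel only; the real _bwd_kernel is far more involved.
@triton.jit
def _copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_M: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

def copy(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_M"]),)
    # Hard-coded meta-parameter, mirroring the explicit BLOCK_M=128 in the hunk above.
    # The alternative is decorating the kernel with @triton.autotune(configs=..., key=...),
    # which the removed TODO notes produced two extra DtoD memcpys at the time.
    _copy_kernel[grid](x, out, n_elements, BLOCK_M=128)
    return out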
@@ -610,10 +609,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         qkv, o, lse = ctx.saved_tensors
-        dqkv = torch.empty_like(qkv)
-        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
-                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
+                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dqkv, None, None
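The same pattern is applied in the two hunks below. As a standalone illustration of what it does, here is a minimal sketch, not the flash-attention code, of a custom autograd Function whose backward pre-allocates the gradient buffer and fills it in place under torch.inference_mode(), so that the kernel's writes do not bump the buffer's version counter; fill_grad_kernel is a hypothetical stand-in for _flash_attn_backward.

import torch

# Hypothetical stand-in for _flash_attn_backward: it writes the gradient into a
# pre-allocated buffer in place rather than allocating and returning a new tensor.
def fill_grad_kernel(grad_out, x, dx):
    dx.copy_(2.0 * x * grad_out)   # d/dx (x^2) = 2x, scaled by the incoming gradient

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # As in the hunk above: allocate and fill the gradient under inference_mode,
        # so the in-place writes are not tracked by the version counter.
        with torch.inference_mode():
            dx = torch.empty_like(x)
            fill_grad_kernel(grad_out, x, dx)
        return dx

x = torch.randn(4, requires_grad=True)
y = Square.apply(x)
(dx,) = torch.autograd.grad(y.sum(), x)   # same style of gradient check as the test below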
@@ -640,11 +642,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, kv, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dkv = torch.empty_like(kv)
-        _flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
-                             dq, dkv[:, :, 0], dkv[:, :, 1],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
+                                 dq, dkv[:, :, 0], dkv[:, :, 1],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dkv, None, None
@@ -669,11 +674,14 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, k, v, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
-        _flash_attn_backward(do, q, k, v, o, lse,
-                             dq, dk, dv,
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            _flash_attn_backward(do, q, k, v, o, lse,
+                                 dq, dk, dv,
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dk, dv, None, None
tests/test_flash_attn.py
@@ -944,12 +944,18 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
     for i in range(10000):
         output = flash_attn_func(q, k, v, causal)
-        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
-        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
-        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
-        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        output_equal = torch.equal(output, output_0)
+        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
+            print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
-        # assert torch.equal(dq, dq_0)
-        # assert torch.equal(dk, dk_0)
-        # assert torch.equal(dv, dv_0)
+        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        dq_equal = torch.equal(dq, dq_0)
+        dk_equal = torch.equal(dk, dk_0)
+        dv_equal = torch.equal(dv, dv_0)
+        if not (dq_equal and dk_equal and dv_equal):
+            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(dq, dq_0)
+        assert torch.equal(dk, dk_0)
+        assert torch.equal(dv, dv_0)
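For reference, here is a small self-contained sketch (toy shapes and a stand-in op, not the real test) of the torch.autograd.grad call added above: it computes gradients with respect to the listed inputs and returns them directly instead of accumulating into .grad, so each loop iteration compares freshly computed dq, dk, dv against the reference values.

import torch

q = torch.randn(2, 8, requires_grad=True)
k = torch.randn(2, 8, requires_grad=True)
v = torch.randn(2, 8, requires_grad=True)

output = q * k + v                 # stand-in for flash_attn_func(q, k, v, causal)
g = torch.randn_like(output)       # upstream gradient, playing the role of `g` in the test

# Returns (dq, dk, dv) directly; nothing is written to q.grad / k.grad / v.grad,
# so repeated calls in a loop stay independent of one another.
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)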