gaoqiong / flash-attention

Commit b910bf14, authored Oct 30, 2022 by Tri Dao

Support arbitrary seqlens (both q & k) in Triton bwd

Parent: dc554693
Showing 2 changed files with 42 additions and 32 deletions:

  flash_attn/flash_attn_triton.py  +27 -14
  tests/test_flash_attn.py         +15 -18
flash_attn/flash_attn_triton.py (view file @ b910bf14)
@@ -5,10 +5,8 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Implement both causal and non-causal attention.
 - Implement cross-attention (not just self-attention).
-- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
-- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
-  must still be a multiple of 128.
-- Speed up the forward pass a bit (and only store the LSE instead of m and l).
+- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
+- Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
   small batch size * nheads.
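The hunks below implement the claim in the updated docstring: query blocks that run past seqlen_q are loaded with mask=offs_m_curr[:, None] < seqlen_q, other=0.0 and written back (or atomically added) under the same mask, so out-of-range rows simply contribute nothing to the accumulated gradients. A minimal PyTorch sketch of that argument for the dV accumulation, with made-up sizes and names (illustrative only, not the Triton kernel):

# Padding the query dimension with zeros and accumulating block by block gives the
# same dV as the unpadded computation, which is why masked loads with other=0.0 are
# enough to support an arbitrary seqlen_q in the backward pass.
import torch

torch.manual_seed(0)
seqlen_q, seqlen_k, d, block_m = 113, 203, 64, 128  # seqlen_q is not a multiple of 128
p = torch.rand(seqlen_q, seqlen_k)    # softmax probabilities (recomputed in the bwd)
do = torch.randn(seqlen_q, d)         # gradient of the output

dv_ref = p.t() @ do                   # unpadded reference: dV = P^T dO

# Pad the q dimension up to a multiple of block_m with zeros (what other=0.0 does)
pad = -seqlen_q % block_m
p_pad = torch.nn.functional.pad(p, (0, 0, 0, pad))
do_pad = torch.nn.functional.pad(do, (0, 0, 0, pad))

dv_blocked = torch.zeros(seqlen_k, d)
for start in range(0, seqlen_q + pad, block_m):   # loop over q blocks, as the kernel does
    dv_blocked += p_pad[start:start + block_m].t() @ do_pad[start:start + block_m]

assert torch.allclose(dv_ref, dv_blocked, atol=1e-5)

The same zero-padding argument applies to dK, and (per key block) to the dQ partial sums.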
@@ -18,8 +16,6 @@ import math
 import torch
-from einops import rearrange
 import triton
 import triton.language as tl
@@ -213,7 +209,9 @@ def _bwd_kernel_one_col_block(
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
-    if EVEN_N:
+    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.load(k_ptrs), we get the wrong output!
+    if EVEN_N & EVEN_M:
         k = tl.load(k_ptrs)
         v = tl.load(v_ptrs)
     else:
@@ -225,7 +223,10 @@ def _bwd_kernel_one_col_block(
         start_m = tl.multiple_of(start_m, BLOCK_M)
         offs_m_curr = start_m + offs_m
         # load q, k, v, do on-chip
-        q = tl.load(q_ptrs)
+        if EVEN_M:
+            q = tl.load(q_ptrs)
+        else:
+            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
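The new branch only pays for masking on the boundary block: when BLOCK_M divides seqlen_q (EVEN_M), the original unmasked load is kept. A rough CPU emulation of what the masked load does at the boundary, using hypothetical helper and variable names that are not part of the library:

# Rows of the current q block whose global index is >= seqlen_q are filled with
# `other` (0.0) instead of being read out of bounds, mirroring the masked tl.load.
import torch

def load_q_block(q, start_m, block_m, other=0.0):
    """q: (seqlen_q, d). Returns a (block_m, d) tile, zero-filled past seqlen_q."""
    seqlen_q, d = q.shape
    offs_m_curr = start_m + torch.arange(block_m)
    tile = torch.full((block_m, d), other, dtype=q.dtype)
    in_bounds = offs_m_curr < seqlen_q
    tile[in_bounds] = q[offs_m_curr[in_bounds]]
    return tile

q = torch.randn(113, 64)
tile = load_q_block(q, start_m=0, block_m=128)   # last 15 rows come back as zeros
assert torch.equal(tile[:113], q) and torch.equal(tile[113:], torch.zeros(15, 64))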
@@ -235,7 +236,10 @@ def _bwd_kernel_one_col_block(
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
         # compute dv
-        do = tl.load(do_ptrs)
+        if EVEN_M:
+            do = tl.load(do_ptrs)
+        else:
+            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         dp = tl.dot(do, v, trans_b=True)
@@ -249,12 +253,22 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-            dq += tl.dot(ds, k)
-            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            if EVEN_M:
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            else:
+                dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
+                             eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
+                         eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            tl.atomic_add(dq_ptrs, dq)
+            if EVEN_M:
+                tl.atomic_add(dq_ptrs, dq)
+            else:
+                tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
         # increment pointers
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
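In the ATOMIC_ADD path, each key block computes a partial contribution dS_block @ K_block and accumulates it into the shared dq buffer; the new mask keeps the boundary block from touching rows past seqlen_q. A small PyTorch sketch (sizes and names invented here) of why summing per-key-block partials reproduces the full dQ, which is what tl.atomic_add accumulates on the GPU:

# Splitting dQ = dS @ K across key blocks and summing the partials is exact.
import torch

torch.manual_seed(0)
seqlen_q, seqlen_k, d, block_n = 113, 203, 64, 128
ds = torch.randn(seqlen_q, seqlen_k)   # softmax-backward term dS
k = torch.randn(seqlen_k, d)

dq_ref = ds @ k                        # full dQ in one shot

dq_accum = torch.zeros(seqlen_q, d)
for start in range(0, seqlen_k, block_n):          # one "program" per key block
    end = min(start + block_n, seqlen_k)
    dq_accum += ds[:, start:end] @ k[start:end]    # the partial each block contributes

assert torch.allclose(dq_ref, dq_accum, atol=1e-4)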
@@ -417,7 +431,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     do = do.contiguous()
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
-    assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
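With the assert gone, seqlen_q no longer has to be a multiple of 128; only the saved LSE buffer is padded, using the rounding already present above. For example:

# Quick check of the rounding used for the LSE buffer: seqlen_q itself may now be
# arbitrary, but the stored LSE is padded up to the next multiple of 128.
import math

for seqlen_q in [113, 128, 211, 1024, 1025]:
    print(seqlen_q, '->', math.ceil(seqlen_q / 128) * 128)  # 113 -> 128, ..., 1025 -> 1152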
tests/test_flash_attn.py (view file @ b910bf14)
@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [64, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 128)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2 ** 30:
         pytest.skip()  # Reference implementation OOM
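The (seqlen_q, seqlen_k) pairs deliberately mix lengths that are and are not multiples of 128, with seqlen_q both smaller and larger than seqlen_k, so the backward pass is exercised on partial blocks. One way to run just this test locally (assumes a CUDA GPU with Triton installed; the flags are standard pytest usage, not project-specific):

# Run only the Triton test, either from the shell:
#   pytest tests/test_flash_attn.py -k test_flash_attn_triton -q
# or programmatically:
import pytest
pytest.main(['tests/test_flash_attn.py', '-k', 'test_flash_attn_triton', '-q'])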
@@ -887,25 +887,22 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')

-    run_bwd = seqlen_q % 128 == 0
-    if run_bwd:
-        g = torch.randn_like(output)
-        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
-        dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
-        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
-        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
-        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
-        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
-        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
-        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    g = torch.randn_like(output)
+    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
+    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
+    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
+    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
+    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')

     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
-    if run_bwd:
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()