gaoqiong / flash-attention
Commit dc554693, authored Oct 30, 2022 by Tri Dao
Support arbitrary seqlen_k in Triton bwd
parent d11341fd
Showing 2 changed files with 34 additions and 12 deletions (+34 -12):
flash_attn/flash_attn_triton.py  (+30 -8)
tests/test_flash_attn.py  (+4 -4)
flash_attn/flash_attn_triton.py
...
@@ -3,11 +3,14 @@ Based on the FlashAttention implementation from Phil Tillet.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
- Support both causal and non-causal attention.
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
must still be a multiple of 128.
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
- Make the backward for d=128 much faster by reducing register spilling.
- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
"""
...
@@ -190,6 +193,7 @@ def _bwd_kernel_one_col_block(
    ATOMIC_ADD: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
...
@@ -209,8 +213,12 @@ def _bwd_kernel_one_col_block(
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    if EVEN_N:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
    else:
        k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
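When EVEN_N is false, the tail block of k and v is brought in with a masked load: rows past seqlen_k are filled with other=0.0 and later neutralized by the -inf masking of qk. A rough PyTorch analogue of that masked gather, with invented sizes (the names only mirror the kernel's variables):

import torch

seqlen_k, headdim, BLOCK_N = 203, 64, 128    # invented sizes
k_global = torch.randn(seqlen_k, headdim)

start_n = 128                                # first row of the tail block
offs_n = start_n + torch.arange(BLOCK_N)     # candidate rows, some out of range
valid = offs_n < seqlen_k

# Analogue of tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0):
k_block = torch.zeros(BLOCK_N, headdim)
k_block[valid] = k_global[offs_n[valid]]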
...
@@ -220,6 +228,8 @@ def _bwd_kernel_one_col_block(
        q = tl.load(q_ptrs)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        lse_i = tl.load(LSE + offs_m_curr)
...
@@ -252,8 +262,12 @@ def _bwd_kernel_one_col_block(
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :])
    if EVEN_N:
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
    else:
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)


def init_to_zero(name):
...
@@ -282,6 +296,12 @@ def init_to_zero(name):
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
    # reset_to_zero=['DQ']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V,
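The EVEN_M/EVEN_N heuristics above are what select between the unmasked fast path and the masked tail path at compile time. For a concrete, made-up shape they evaluate as:

BLOCK_M, BLOCK_N = 128, 128          # hypothetical; the autotuner picks the real values
seqlen_q, seqlen_k = 1024, 203

EVEN_M = seqlen_q % BLOCK_M == 0     # True  -> q rows load/store without masks
EVEN_N = seqlen_k % BLOCK_N == 0     # False -> k/v tail block uses masked loads/stores
print(EVEN_M, EVEN_N)                # True False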
...
@@ -300,6 +320,7 @@ def _bwd_kernel(
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    off_hb = tl.program_id(1)
...
@@ -329,6 +350,7 @@ def _bwd_kernel(
                                  ATOMIC_ADD=False,
                                  IS_CAUSAL=IS_CAUSAL,
                                  BLOCK_HEADDIM=BLOCK_HEADDIM,
                                  EVEN_M=EVEN_M, EVEN_N=EVEN_N,
                                  BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    else:
...
@@ -343,6 +365,7 @@ def _bwd_kernel(
                                  ATOMIC_ADD=True,
                                  IS_CAUSAL=IS_CAUSAL,
                                  BLOCK_HEADDIM=BLOCK_HEADDIM,
                                  EVEN_M=EVEN_M, EVEN_N=EVEN_N,
                                  BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
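In the sequence-parallel branch (ATOMIC_ADD=True), each column block of keys computes a partial dq for the same query rows, so the partial results are accumulated rather than overwritten. A toy sketch of that accumulation identity (plain PyTorch, invented sizes; p_ds stands in for the dS term):

import torch

seqlen_q, seqlen_k, d, BLOCK_N = 4, 6, 2, 2   # toy sizes
p_ds = torch.randn(seqlen_q, seqlen_k)        # stand-in for dS
k = torch.randn(seqlen_k, d)

dq = torch.zeros(seqlen_q, d)
for start_n in range(0, seqlen_k, BLOCK_N):
    blk = slice(start_n, start_n + BLOCK_N)
    dq += p_ds[:, blk] @ k[blk]               # every column block adds into the full dq

assert torch.allclose(dq, p_ds @ k, atol=1e-5)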
...
@@ -394,8 +417,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
    do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert seqlen_q % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
    assert seqlen_k % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
    assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
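With the assert change above, the wrapper only requires seqlen_q to be a multiple of 128; seqlen_k can be anything. A hedged usage sketch, assuming the flash_attn_func(q, k, v, causal=...) entry point exported by flash_attn.flash_attn_triton and invented shapes:

import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, nheads, d = 2, 4, 64
q = torch.randn(batch, 256, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, 203, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, 203, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)  # seqlen_k = 203 is no longer a problem
out.sum().backward()                         # backward runs: seqlen_q = 256 is a multiple of 128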
...
tests/test_flash_attn.py
...
@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2 ** 30:
        pytest.skip()  # Reference implementation OOM
...
@@ -887,7 +887,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
    run_bwd = seqlen_q % 128 == 0
    if run_bwd:
        g = torch.randn_like(output)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
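With the relaxed run_bwd gate, the backward pass is exercised whenever seqlen_q is a multiple of 128, regardless of seqlen_k. A quick check of the parametrized pairs above, in plain Python:

pairs = [(113, 203), (128, 217), (113, 211), (108, 256),
         (256, 512), (512, 256), (1024, 1024), (2048, 2048)]
for seqlen_q, seqlen_k in pairs:
    run_bwd = seqlen_q % 128 == 0
    print(seqlen_q, seqlen_k, run_bwd)
# (128, 217) is the newly covered pair; the other pairs that run backward
# already satisfied the old "both multiples of 128" condition.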
...