gaoqiong / flash-attention · Commit 008951f1
Authored Oct 30, 2022 by Tri Dao

Support all head dimensions up to 128 in the Triton fwd

Parent: b910bf14
Showing 2 changed files with 68 additions and 31 deletions:

flash_attn/flash_attn_triton.py  (+51, -17)
tests/test_flash_attn.py         (+17, -14)
flash_attn/flash_attn_triton.py
...
...
@@ -6,6 +6,7 @@ Changes:
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), in the forward pass.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
...
...
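The gist of the change: the forward kernel now rounds the head-dimension tile up to a power of two (BLOCK_HEADDIM) and masks off the padded columns, so any d <= 128 works. Below is a minimal PyTorch sketch of that padding-and-masking idea; the helper name is hypothetical and only mirrors the kernel's variable names, it is not the Triton code.

import torch

def load_padded_tile(x, block_headdim):
    # x: (seqlen, headdim) for one batch/head; pad columns up to block_headdim
    seqlen, headdim = x.shape
    cols = torch.arange(block_headdim)
    keep = cols < headdim                      # same role as offs_d[None, :] < headdim
    tile = torch.zeros(seqlen, block_headdim, dtype=x.dtype)
    tile[:, keep] = x                          # masked "load"; padded columns stay 0.0
    return tile

q = torch.randn(128, 40)                       # headdim=40 is not a power of two
q_tile = load_padded_tile(q, 64)               # BLOCK_HEADDIM = next power of two = 64
assert torch.equal(q_tile[:, :40], q) and q_tile[:, 40:].abs().sum() == 0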
@@ -31,6 +32,7 @@ import triton.language as tl
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
...
...
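@triton.heuristics turns each of these lambdas into a constexpr flag, so the kernel is specialized per shape and the masked branches can be skipped when the sizes divide the blocks evenly. A standalone sketch of what the new EVEN_HEADDIM flag (together with the existing ones) evaluates to; the lambdas are copied from the hunk above, the example shapes are made up.

heuristics = {
    "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
    "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
    "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}

args = dict(seqlen_q=113, seqlen_k=203, headdim=40, BLOCK_M=128, BLOCK_N=128, BLOCK_HEADDIM=64)
flags = {name: fn(args) for name, fn in heuristics.items()}
print(flags)  # {'EVEN_M': False, 'EVEN_N': False, 'EVEN_HEADDIM': False}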
@@ -42,11 +44,11 @@ def _fwd_kernel(
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
-    nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
+    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
...
...
@@ -76,19 +78,34 @@ def _fwd_kernel(
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        if not EVEN_N:
...
...
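The nested branches above exist so that tl.load only pays for the masks it actually needs; in the fully masked case the sequence-length bound and the new head-dimension bound are combined. A rough PyTorch emulation of that mask/other semantics (entries where the mask is False become other=0.0), using a hypothetical masked_load helper with names that mirror the kernel's:

import torch

def masked_load(x, BLOCK_M, BLOCK_HEADDIM, start_m=0):
    # x: (seqlen_q, headdim) slice for one batch/head; returns a zero-padded tile
    seqlen_q, headdim = x.shape
    offs_m = start_m * BLOCK_M + torch.arange(BLOCK_M)
    offs_d = torch.arange(BLOCK_HEADDIM)
    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
    rows = offs_m.clamp(max=seqlen_q - 1)[:, None].expand(BLOCK_M, BLOCK_HEADDIM)
    cols = offs_d.clamp(max=headdim - 1)[None, :].expand(BLOCK_M, BLOCK_HEADDIM)
    return torch.where(mask, x[rows, cols], torch.zeros((), dtype=x.dtype))

tile = masked_load(torch.randn(113, 40), BLOCK_M=128, BLOCK_HEADDIM=64)
assert tile.shape == (128, 64)
assert tile[113:, :].abs().sum() == 0 and tile[:, 40:].abs().sum() == 0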
@@ -111,10 +128,18 @@ def _fwd_kernel(
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
...
...
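For context on acc_o_scale: the kernel computes a running softmax over blocks of keys, so whenever a new block raises the running row maximum, the existing output accumulator has to be rescaled before p @ v is added. A plain PyTorch sketch of that online-softmax update, loosely following the kernel's variable names (small sizes, no softmax_scale, not the Triton code):

import torch

def attention_online(q, k, v, block_n=4):
    seqlen_q, d = q.shape
    m_i = torch.full((seqlen_q,), float('-inf'))   # running row max
    l_i = torch.zeros(seqlen_q)                    # running sum of exp
    acc_o = torch.zeros(seqlen_q, d)
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = q @ kb.t()
        m_new = torch.maximum(m_i, qk.max(dim=1).values)
        acc_o_scale = torch.exp(m_i - m_new)       # rescale the old accumulator
        p = torch.exp(qk - m_new[:, None])
        l_i = l_i * acc_o_scale + p.sum(dim=1)
        acc_o = acc_o * acc_o_scale[:, None] + p @ vb
        m_i = m_new
    return acc_o / l_i[:, None]

q, k, v = torch.randn(8, 16), torch.randn(12, 16), torch.randn(12, 16)
ref = torch.softmax(q @ k.t(), dim=-1) @ v
assert torch.allclose(attention_online(q, k, v), ref, atol=1e-5)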
@@ -138,9 +163,16 @@ def _fwd_kernel(
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))


@triton.heuristics(
...
...
@@ -209,8 +241,8 @@ def _bwd_kernel_one_col_block(
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
-    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_N=False,
+    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
...
...
@@ -390,7 +422,8 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
-    assert d in {16, 32, 64, 128}
+    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
+    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
...
...
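The relaxed assert plus BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) is what opens up the odd head dimensions: every d <= 128 maps to a tile width of 16, 32, 64, or 128, and EVEN_HEADDIM is only true when d already equals that width. A quick illustration; next_power_of_2 below is a local stand-in for triton.next_power_of_2 so the snippet runs without a GPU.

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

for d in [16, 32, 40, 64, 88, 96, 128]:
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    BLOCK_HEADDIM = max(next_power_of_2(d), 16)
    print(d, BLOCK_HEADDIM, d == BLOCK_HEADDIM)   # last column is EVEN_HEADDIM
# 40 -> 64 and 88 -> 128: these take the masked (EVEN_HEADDIM=False) path.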
@@ -413,11 +446,11 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        o.stride(0), o.stride(2), o.stride(1),
-        nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
+        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal, d,
+        causal, BLOCK_HEADDIM,
        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        # num_warps=num_warps,
        # num_stages=1,
...
...
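On the seqlen_q // 32 and seqlen_k // 32 arguments (the "key for triton cache" comment above): passing bucketed lengths instead of the exact ones means many nearby sequence lengths share one cache key, so the kernel is recompiled far less often. A tiny illustration with made-up lengths:

seqlens = [113, 120, 127, 128, 129, 160, 256, 257]
keys = sorted({s // 32 for s in seqlens})
print(keys)   # [3, 4, 5, 8] -> 8 different lengths, only 4 distinct cache keys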
@@ -431,6 +464,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
    do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert d in {16, 32, 64, 128}
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
...
...
tests/test_flash_attn.py
...
...
@@ -861,7 +861,7 @@ from flash_attn.flash_attn_triton import flash_attn_func
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [64, 128])
+@pytest.mark.parametrize('d', [40, 64, 128, 88])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
...
...
@@ -887,6 +887,8 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    run_bwd = d in [16, 32, 64, 128]
    if run_bwd:
        g = torch.randn_like(output)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
...
...
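With the new run_bwd guard, the test exercises the Triton forward for every parametrized head dimension but only checks gradients for the dimensions the backward kernel still supports (it keeps the assert d in {16, 32, 64, 128}). Which of the new d values that covers:

ds = [40, 64, 128, 88]                         # head dimensions from the parametrization above
print({d: d in [16, 32, 64, 128] for d in ds})
# {40: False, 64: True, 128: True, 88: False} -> backward checks are skipped for 40 and 88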
@@ -903,6 +905,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    if run_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
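The assertions above all follow the same pattern: accept the Triton output if its max error against the fp32 reference is at most twice the error of the plain fp16/bf16 PyTorch implementation. A small helper sketch of that pattern in isolation; the function name and the stand-in tensors are hypothetical.

import torch

def close_enough(out, out_pt, out_ref, factor=2):
    # Kernel error vs. `factor` times the baseline implementation's error.
    err = (out - out_ref).abs().max().item()
    err_pt = (out_pt - out_ref).abs().max().item()
    return err <= factor * err_pt

out_ref = torch.randn(4, 8)       # stand-in for the fp32 reference output
out_pt = out_ref + 1e-3           # stand-in for the fp16 PyTorch output
out = out_ref + 5e-4              # stand-in for the Triton output
print(close_enough(out, out_pt, out_ref))   # True: 5e-4 <= 2 * 1e-3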