gaoqiong / flash-attention · Commits

Commit 86862cfd, authored Nov 04, 2022 by Tri Dao
Parent: 470010f5

Implement attention bias for Triton version

Showing 2 changed files with 225 additions and 64 deletions:
  flash_attn/flash_attn_triton.py  +180 -48
  tests/test_flash_attn.py          +45 -16

flash_attn/flash_attn_triton.py
"""
"""
*Experimental* implementation of FlashAttention in Triton.
We use the FlashAttention implementation from Phil Tillet a starting point.
We use the FlashAttention implementation from Phil Tillet a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
...
@@ -7,6 +9,7 @@ Changes:
...
@@ -7,6 +9,7 @@ Changes:
- Implement both self-attention and cross-attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
...
@@ -31,6 +34,8 @@ import math
 import torch
+from einops import rearrange, repeat
 import triton
 import triton.language as tl
...
@@ -41,7 +46,7 @@ import triton.language as tl
         # This config has a race condition when EVEN_M == False, disabling it for now.
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
     ],
-    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
+    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
 )
 @triton.heuristics(
     {
...
@@ -52,15 +57,17 @@ import triton.language as tl
 )
 @triton.jit
 def _fwd_kernel(
-    Q, K, V, Out,
+    Q, K, V, Bias, Out,
     Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     softmax_scale,
     stride_qb, stride_qh, stride_qm,
     stride_kb, stride_kh, stride_kn,
     stride_vb, stride_vh, stride_vn,
+    stride_bb, stride_bh, stride_bm,
     stride_ob, stride_oh, stride_om,
     nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
     CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
+    BIAS_TYPE: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
...
@@ -84,6 +91,10 @@ def _fwd_kernel(
     q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
     k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
     v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
+    if BIAS_TYPE == 'vector':
+        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
+    elif BIAS_TYPE == 'matrix':
+        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
     # initialize pointer to m and l
     t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
     lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
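In dense PyTorch terms, the two BIAS_TYPE layouts set up above differ only in how the bias is indexed before it is added to a tile of attention scores: a 'vector' bias holds one value per key position and is broadcast over the query rows, while a 'matrix' bias holds one value per (query, key) pair. A minimal sketch of the equivalent full-size computation (tensor names here are illustrative, not taken from the kernel):

    import torch

    batch, nheads, seqlen_q, seqlen_k = 2, 4, 128, 256
    scores = torch.randn(batch, nheads, seqlen_q, seqlen_k)

    # 'vector' bias: one value per key position, broadcast over all query rows
    bias_vec = torch.randn(batch, nheads, 1, seqlen_k)
    scores_vec = scores + bias_vec      # broadcasts along the query dimension

    # 'matrix' bias: one value per (query, key) pair, added elementwise
    bias_mat = torch.randn(batch, nheads, seqlen_q, seqlen_k)
    scores_mat = scores + bias_mat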
...
@@ -123,14 +134,35 @@ def _fwd_kernel(
                     other=0.0)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k, trans_b=True)
         # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
-        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-        # Slightly faster to multiply the softmax_scale here since the compiler can then
-        # fuse the mult and add into an fma instruction.
-        p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        if BIAS_TYPE != 'none':
+            if BIAS_TYPE == 'vector':
+                if EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k,
+                                   other=0.0).to(tl.float32)
+                bias = bias[None, :]
+            elif BIAS_TYPE == 'matrix':
+                if EVEN_M & EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs + start_n,
+                                   mask=(offs_m[:, None] < seqlen_q)
+                                        & ((start_n + offs_n)[None, :] < seqlen_k),
+                                   other=0.0).to(tl.float32)
+            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
+            # can then fuse the mult and add into an fma instruction. But if we have bias we need
+            # to multiply with softmax_scale here.
+            qk = qk * softmax_scale + bias
+            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
+            p = tl.exp(qk - m_ij[:, None])
+        else:
+            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+            p = tl.exp(qk * softmax_scale - m_ij[:, None])
         l_ij = tl.sum(p, 1)
         # scale acc_o
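The net effect of this branch is that, when a bias is present, the softmax is taken over q @ k^T * softmax_scale + bias, so the scale has to be folded in before the running maximum m_ij is computed rather than inside tl.exp. A dense PyTorch sketch of the quantity the online softmax accumulates (ignoring the sequence-length masking and the causal term; the function name is illustrative):

    import math
    import torch

    def biased_attention_reference(q, k, v, bias=None, softmax_scale=None):
        # q, k, v: (batch, seqlen, nheads, headdim); bias broadcastable to (batch, nheads, seqlen_q, seqlen_k)
        softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bqhd,bkhd->bhqk', q, k) * softmax_scale
        if bias is not None:
            scores = scores + bias
        p = torch.softmax(scores, dim=-1)
        return torch.einsum('bhqk,bkhd->bqhd', p, v)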
...
@@ -218,12 +250,15 @@ def _bwd_preprocess_do_o_dot(
 @triton.jit
 def _bwd_kernel_one_col_block(
     start_n,
-    Q, K, V, softmax_scale,
+    Q, K, V, Bias,
     DO, DQ, DK, DV,
     LSE, D,
-    stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
+    softmax_scale,
+    stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn,
     seqlen_q, seqlen_k, headdim,
     ATOMIC_ADD: tl.constexpr,
+    BIAS_TYPE: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
...
@@ -242,6 +277,10 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
+    if BIAS_TYPE == 'vector':
+        b_ptrs = Bias + offs_n
+    elif BIAS_TYPE == 'matrix':
+        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
...
@@ -286,12 +325,31 @@ def _bwd_kernel_one_col_block(
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+        if BIAS_TYPE != 'none':
+            if BIAS_TYPE == 'vector':
+                if EVEN_N:
+                    bias = tl.load(b_ptrs).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
+                bias = bias[None, :]
+            elif BIAS_TYPE == 'matrix':
+                if EVEN_M & EVEN_N:
+                    bias = tl.load(b_ptrs).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+                                                 & (offs_n[None, :] < seqlen_k),
+                                   other=0.0).to(tl.float32)
+            qk = qk * softmax_scale + bias
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
-        p = tl.exp(qk * softmax_scale - lse_i[:, None])
+        if BIAS_TYPE == 'none':
+            p = tl.exp(qk * softmax_scale - lse_i[:, None])
+        else:
+            p = tl.exp(qk - lse_i[:, None])
         # compute dv
         # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
         # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
...
@@ -368,6 +426,8 @@ def _bwd_kernel_one_col_block(
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
         do_ptrs += BLOCK_M * stride_dom
+        if BIAS_TYPE == 'matrix':
+            b_ptrs += BLOCK_M * stride_bm
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
...
@@ -392,6 +452,7 @@ def _bwd_kernel_one_col_block(
 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()


 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
...
@@ -403,7 +464,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
+    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
 )
 @triton.heuristics(
     {
...
@@ -414,19 +475,21 @@ def init_to_zero(name):
 )
 @triton.jit
 def _bwd_kernel(
-    Q, K, V,
+    Q, K, V, Bias,
     DO, DQ, DK, DV,
     LSE, D,
     softmax_scale,
     stride_qb, stride_qh, stride_qm,
     stride_kb, stride_kh, stride_kn,
     stride_vb, stride_vh, stride_vn,
+    stride_bb, stride_bh, stride_bm,
     stride_dob, stride_doh, stride_dom,
     stride_dqb, stride_dqh, stride_dqm,
     stride_dkb, stride_dkh, stride_dkn,
     stride_dvb, stride_dvh, stride_dvn,
     nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
     CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
+    BIAS_TYPE: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
...
@@ -444,6 +507,8 @@ def _bwd_kernel(
     DQ += off_b * stride_dqb + off_h * stride_dqh
     DK += off_b * stride_dkb + off_h * stride_dkh
     DV += off_b * stride_dvb + off_h * stride_dvh
+    if BIAS_TYPE != 'none':
+        Bias += off_b * stride_bb + off_h * stride_bh
     # pointer to row-wise quantities in value-like data
     D += off_hb * seqlen_q_rounded
     LSE += off_hb * seqlen_q_rounded
...
@@ -452,12 +517,15 @@ def _bwd_kernel(
         for start_n in range(0, num_block_n):
             _bwd_kernel_one_col_block(
                 start_n,
-                Q, K, V, softmax_scale,
+                Q, K, V, Bias,
                 DO, DQ, DK, DV,
                 LSE, D,
-                stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
+                softmax_scale,
+                stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn,
                 seqlen_q, seqlen_k, headdim,
                 ATOMIC_ADD=False,
+                BIAS_TYPE=BIAS_TYPE,
                 IS_CAUSAL=IS_CAUSAL,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
...
@@ -467,12 +535,15 @@ def _bwd_kernel(
         start_n = tl.program_id(0)
         _bwd_kernel_one_col_block(
             start_n,
-            Q, K, V, softmax_scale,
+            Q, K, V, Bias,
             DO, DQ, DK, DV,
             LSE, D,
-            stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
+            softmax_scale,
+            stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn,
             seqlen_q, seqlen_k, headdim,
             ATOMIC_ADD=True,
+            BIAS_TYPE=BIAS_TYPE,
             IS_CAUSAL=IS_CAUSAL,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
...
@@ -480,7 +551,7 @@ def _bwd_kernel(
 )


-def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
+def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
...
@@ -491,10 +562,31 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
     assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
     assert q.is_cuda and k.is_cuda and v.is_cuda
     softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+    has_bias = bias is not None
+    bias_type = 'none'
+    if has_bias:
+        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.is_cuda
+        assert bias.dim() == 4
+        if bias.stride(-1) != 1:
+            bias = bias.contiguous()
+        if bias.shape[2:] == (1, seqlen_k):
+            bias_type = 'vector'
+        elif bias.shape[2:] == (seqlen_q, seqlen_k):
+            bias_type = 'matrix'
+        else:
+            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
+                               ' or (seqlen_q, seqlen_k)')
+        if bias.shape[:2] == (1, nheads):
+            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
+        elif bias.shape[:2] == (batch, 1):
+            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
+        assert bias.shape[:2] == (batch, nheads), \
+            'First 2 dimensions of bias must be broadcastable to (batch, nheads)'
+    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    # lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device,
+    #                  dtype=torch.float32)
     tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
     o = torch.empty_like(q)
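The shape handling above accepts any 4-D bias whose last two dimensions are (1, seqlen_k) ('vector') or (seqlen_q, seqlen_k) ('matrix') and whose first two dimensions broadcast to (batch, nheads); singleton batch or head dimensions are expanded with einops.repeat before the strides are passed to the kernel. A small sketch of inputs each branch would classify (shapes and names are illustrative):

    import torch
    from einops import repeat

    batch, nheads, seqlen_q, seqlen_k = 2, 4, 128, 256

    alibi_like = torch.randn(1, nheads, 1, seqlen_k, device='cuda')           # classified as 'vector'
    rel_pos_like = torch.randn(1, nheads, seqlen_q, seqlen_k, device='cuda')  # classified as 'matrix'
    per_batch = torch.randn(batch, 1, 1, seqlen_k, device='cuda')             # 'vector', head dim expanded

    # Mirrors the broadcast done above before reading bias.stride(0), (1), (2):
    per_batch = repeat(per_batch, 'b 1 ... -> b h ...', h=nheads)
    assert per_batch.shape[:2] == (batch, nheads)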
...
@@ -503,18 +595,19 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
     # num_warps = 4 if d <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
     _fwd_kernel[grid](
-        q, k, v, o,
+        q, k, v, bias, o,
         lse, tmp,
         softmax_scale,
         q.stride(0), q.stride(2), q.stride(1),
         k.stride(0), k.stride(2), k.stride(1),
         v.stride(0), v.stride(2), v.stride(1),
+        *bias_strides,
         o.stride(0), o.stride(2), o.stride(1),
         nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
         seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal, BLOCK_HEADDIM,
+        bias_type, causal, BLOCK_HEADDIM,
         # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
         # num_warps=num_warps,
         # num_stages=1,
...
@@ -522,7 +615,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
     return o, lse, softmax_scale  # softmax_scale could have been updated


-def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None):
+def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
         do = do.contiguous()
...
@@ -532,6 +625,8 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
+    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
+    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
     softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
...
@@ -548,19 +643,41 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
+    has_bias = bias is not None
+    bias_type = 'none'
+    if has_bias:
+        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.is_cuda
+        assert bias.dim() == 4
+        assert bias.stride(-1) == 1
+        if bias.shape[2:] == (1, seqlen_k):
+            bias_type = 'vector'
+        elif bias.shape[2:] == (seqlen_q, seqlen_k):
+            bias_type = 'matrix'
+        else:
+            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
+                               ' or (seqlen_q, seqlen_k)')
+        if bias.shape[:2] == (1, nheads):
+            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
+        elif bias.shape[:2] == (batch, 1):
+            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
+        assert bias.shape[:2] == (batch, nheads), \
+            'First 2 dimensions of bias must be broadcastable to (batch, nheads)'
+    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
     grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                          batch * nheads)
     _bwd_kernel[grid](
-        q, k, v,
+        q, k, v, bias,
         do, dq_accum, dk, dv,
         lse, delta,
         softmax_scale,
         q.stride(0), q.stride(2), q.stride(1),
         k.stride(0), k.stride(2), k.stride(1),
         v.stride(0), v.stride(2), v.stride(1),
+        *bias_strides,
         do.stride(0), do.stride(2), do.stride(1),
         dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
         dk.stride(0), dk.stride(2), dk.stride(1),
...
@@ -569,7 +686,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal, BLOCK_HEADDIM,
+        bias_type, causal, BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
@@ -581,31 +698,36 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
...
@@ -581,31 +698,36 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
class
FlashAttnQKVPackedFunc
(
torch
.
autograd
.
Function
):
class
FlashAttnQKVPackedFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
qkv
,
causal
=
False
,
softmax_scale
=
None
):
def
forward
(
ctx
,
qkv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
"""
"""
qkv: (batch, seqlen, 3, nheads, headdim)
qkv: (batch, seqlen, 3, nheads, headdim)
bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
"""
"""
# Make sure that the last dimension is contiguous
# Make sure that the last dimension is contiguous
if
qkv
.
stride
(
-
1
)
!=
1
:
if
qkv
.
stride
(
-
1
)
!=
1
:
qkv
=
qkv
.
contiguous
()
qkv
=
qkv
.
contiguous
()
o
,
lse
,
ctx
.
softmax_scale
=
_flash_attn_forward
(
o
,
lse
,
ctx
.
softmax_scale
=
_flash_attn_forward
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
causal
=
causal
,
softmax_scale
=
softmax_scale
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
bias
=
bias
,
causal
=
causal
,
softmax_scale
=
softmax_scale
)
)
ctx
.
save_for_backward
(
qkv
,
o
,
lse
)
ctx
.
save_for_backward
(
qkv
,
o
,
lse
,
bias
)
ctx
.
causal
=
causal
ctx
.
causal
=
causal
return
o
return
o
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
do
):
def
backward
(
ctx
,
do
):
qkv
,
o
,
lse
=
ctx
.
saved_tensors
qkv
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
assert
not
ctx
.
needs_input_grad
[
1
],
'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with
torch
.
inference_mode
():
with
torch
.
inference_mode
():
dqkv
=
torch
.
empty_like
(
qkv
)
dqkv
=
torch
.
empty_like
(
qkv
)
_flash_attn_backward
(
do
,
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
o
,
lse
,
_flash_attn_backward
(
do
,
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
o
,
lse
,
dqkv
[:,
:,
0
],
dqkv
[:,
:,
1
],
dqkv
[:,
:,
2
],
dqkv
[:,
:,
0
],
dqkv
[:,
:,
1
],
dqkv
[:,
:,
2
],
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
)
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
)
return
dqkv
,
None
,
None
return
dqkv
,
None
,
None
,
None
flash_attn_qkvpacked_func
=
FlashAttnQKVPackedFunc
.
apply
flash_attn_qkvpacked_func
=
FlashAttnQKVPackedFunc
.
apply
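For the causal ALiBi case mentioned in the docstring, the (1, nheads, 1, seqlen) 'vector' layout is enough: the bias -slope * key_index differs from the full -slope * (query_index - key_index) only by a per-row constant, which the softmax cancels. A hedged sketch of building such a bias (the slope schedule follows the ALiBi paper; these helper names are not part of this file):

    import torch

    def alibi_slopes(nheads):
        # Powers-of-two slope schedule from the ALiBi paper (assumes nheads is a power of 2).
        start = 2 ** (-8 / nheads)
        return torch.tensor([start ** (i + 1) for i in range(nheads)])

    def alibi_bias_causal(nheads, seqlen, device='cuda', dtype=torch.float):
        # Shape (1, nheads, 1, seqlen): one value per key position, broadcast over query rows.
        slopes = alibi_slopes(nheads).to(device=device, dtype=dtype)
        positions = torch.arange(seqlen, device=device, dtype=dtype)
        return -slopes.view(1, nheads, 1, 1) * positions.view(1, 1, 1, seqlen)

For non-causal attention the full (1, nheads, seqlen_q, seqlen_k) 'matrix' form with -slope * |i - j| would be needed instead, as the docstring notes.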
...
@@ -614,32 +736,36 @@ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
 class FlashAttnKVPackedFunc(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, q, kv, causal=False, softmax_scale=None):
+    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
         """
-        q: (batch, seqlen, nheads, headdim)
-        kv: (batch, seqlen, 2, nheads, headdim)
+        q: (batch, seqlen_q, nheads, headdim)
+        kv: (batch, seqlen_k, 2, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
         """
         # Make sure that the last dimension is contiguous
         q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale
+            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
         )
-        ctx.save_for_backward(q, kv, o, lse)
+        ctx.save_for_backward(q, kv, o, lse, bias)
         ctx.causal = causal
         return o

     @staticmethod
     def backward(ctx, do):
-        q, kv, o, lse = ctx.saved_tensors
+        q, kv, o, lse, bias = ctx.saved_tensors
+        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
             dq = torch.empty_like(q)
             dkv = torch.empty_like(kv)
             _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                                  dq, dkv[:, :, 0], dkv[:, :, 1],
-                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
-        return dq, dkv, None, None
+        return dq, dkv, None, None, None


 flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
...
@@ -648,21 +774,27 @@ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
 class FlashAttnFunc(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, q, k, v, causal=False, softmax_scale=None):
+    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
         """
-        q, k, v: (batch_size, seqlen, nheads, headdim)
+        q: (batch_size, seqlen_q, nheads, headdim)
+        k, v: (batch_size, seqlen_k, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
         """
         # Make sure that the last dimension is contiguous
         q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
-        o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, causal=causal,
-                                                        softmax_scale=softmax_scale)
-        ctx.save_for_backward(q, k, v, o, lse)
+        o, lse, ctx.softmax_scale = _flash_attn_forward(
+            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
+        )
+        ctx.save_for_backward(q, k, v, o, lse, bias)
         ctx.causal = causal
         return o

     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, lse = ctx.saved_tensors
+        q, k, v, o, lse, bias = ctx.saved_tensors
+        assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
...
@@ -670,8 +802,8 @@ class FlashAttnFunc(torch.autograd.Function):
             dk = torch.empty_like(k)
             dv = torch.empty_like(v)
             _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
-                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+                                 bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
-        return dq, dk, dv, None, None
+        return dq, dk, dv, None, None, None


 flash_attn_func = FlashAttnFunc.apply
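Because FlashAttnFunc is exposed through Function.apply, the new bias argument is passed positionally, exactly as the updated tests below call it. A minimal usage sketch, assuming a CUDA device and fp16 inputs:

    import torch
    from flash_attn.flash_attn_triton import flash_attn_func

    batch, seqlen, nheads, d = 2, 1024, 8, 64
    q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                           requires_grad=True) for _ in range(3)]
    # 'vector' bias, broadcast over the batch and over query positions
    bias = torch.randn(1, nheads, 1, seqlen, device='cuda', dtype=torch.float)

    out = flash_attn_func(q, k, v, bias, True)   # positional: (q, k, v, bias, causal)
    out.sum().backward()                         # dq, dk, dv flow; the bias gets no gradient yet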
tests/test_flash_attn.py
...
@@ -122,7 +122,7 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None
 def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
-                  dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
+                  dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False):
     """
     Arguments:
         q: (batch_size, seqlen_q, nheads, head_dim)
...
@@ -132,6 +132,7 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
         key_padding_mask: (batch_size, seqlen_k)
         dropout_p: float
         dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        bias: (batch_size, nheads, seqlen_q, seqlen_k)
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
         reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
...
@@ -150,6 +151,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
         scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
     else:
         scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
+    if bias is not None:
+        scores = (scores + bias).to(dtype=scores.dtype)
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
     if causal:
...
@@ -863,11 +866,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [48])
+# @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+# @pytest.mark.parametrize('bias_shape', (['1h1k']))
-def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
+def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
...
@@ -877,12 +882,23 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output = flash_attn_func(q, k, v, causal)
+    output = flash_attn_func(q, k, v, bias, causal)
-    output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
+    output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal)
-    output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
+    output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False,
+                                       reorder_ops=True)
     print(f'Output max diff: {(output - output_ref).abs().max().item()}')
     print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
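A gradient check in the same style would compare dq, dk and dv against the two references, since the bias itself receives no gradient yet; a sketch, with the factor-of-2 tolerance as an illustrative choice:

    g = torch.randn_like(output)
    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g, retain_graph=True)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (q, k, v), g, retain_graph=True)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (q, k, v), g)

    # The Triton gradients should be about as close to the fp32 reference as the fp16 PyTorch ones.
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()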
...
@@ -919,13 +935,14 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
-# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-@pytest.mark.parametrize('d', [64, 128])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [96])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
-def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
...
@@ -935,19 +952,31 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output_0 = flash_attn_func(q, k, v, causal)
+    output_0 = flash_attn_func(q, k, v, bias, causal)
     g = torch.randn_like(output_0)
     dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
     # The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
     deterministic_dq = False
-    equal_fn = (torch.equal if deterministic_dq
-                # Numerical error if we just do any arithmetic on dq
-                else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
+    dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
+    # Run 10000 times and check that the results don't change
     for i in range(10000):
-        output = flash_attn_func(q, k, v, causal)
+        output = flash_attn_func(q, k, v, bias, causal)
         output_equal = torch.equal(output, output_0)
         if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
...