gaoqiong / flash-attention · Commits

Commit b0c0db81, authored Oct 30, 2022 by Tri Dao
parent c422fee3

Implement FlashAttention in Triton

Showing 4 changed files with 602 additions and 109 deletions
benchmarks/benchmark_causal.py        +13 -16
flash_attn/flash_attn_triton.py       +529 -0
flash_attn/flash_attn_triton_og.py    +5 -93
tests/test_flash_attn.py              +55 -0
benchmarks/benchmark_causal.py

@@ -6,9 +6,11 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.triton.fused_attention import attention as attention
# from flash_attn.triton.fused_attention import attention as attention
from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
from flash_attn.flash_attn_triton_og import attention as attention_og
try:
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    return output.to(dtype=qkv.dtype)


def attention_triton(q, k, v):
    """
    No dropout and only support causal=True.
    Triton implementation seems to require q, k, v being contiguous?
    Arguments:
        q, k, v: (batch_size, nheads, seqlen, head_dim)
    Output:
        output: (batch_size, nheads, seqlen, head_dim)
    """
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    return attention(q, k, v, softmax_scale)


def attention_megatron(qkv):
    """
    Arguments:

@@ -85,6 +74,10 @@ batch_size = 2
seqlen = 4096
nheads = 12
headdim = 128
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.bfloat16
@@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention')

benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)

q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                       requires_grad=True) for _ in range(3)]
benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# pytorch_profiler(attention, q, k, v, 1.0, backward=True)

if scaled_upper_triang_masked_softmax is not None:
    benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
flash_attn/flash_attn_triton.py (new file, 0 → 100644)

"""
Based on the FlashAttention implementation from Phil Tillet.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

Changes:
- Support both causal and non-causal attention.
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
- Make the backward for d=128 much faster by reducing register spilling.
- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
  small batch size * nheads.
"""
import math

import torch
from einops import rearrange

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
    }
)
@triton.jit
def _fwd_kernel(
    Q, K, V, Out, Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    if EVEN_M:
        q = tl.load(q_ptrs)
    else:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N:
            k = tl.load(k_ptrs + start_n * stride_kn)
        else:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                        other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
        # Slightly faster to multiply the softmax_scale here since the compiler can then
        # fuse the mult and add into an fma instruction.
        p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)
        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N:
            v = tl.load(v_ptrs + start_n * stride_vn)
        else:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                        other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    lse_ptrs = Lse + off_hb * seqlen_q + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        tl.store(out_ptrs, acc_o)
    else:
        tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
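For readers less familiar with the online-softmax formulation: what the forward kernel computes for one (batch, head) slice is ordinary softmax attention plus the per-row log-sum-exp that the backward pass reuses. A dense PyTorch sketch of the same math (editor's illustration with O(seqlen_q * seqlen_k) memory, not part of this commit):

import torch

def reference_fwd(q, k, v, causal, softmax_scale):
    # q: (seqlen_q, d); k, v: (seqlen_k, d)
    scores = (q.float() @ k.float().t()) * softmax_scale
    if causal:
        # same top-left-aligned mask as the kernel: query m attends to keys n <= m
        mask = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float('-inf'))
    lse = torch.logsumexp(scores, dim=-1)                    # what the kernel writes to Lse
    out = (torch.exp(scores - lse[:, None]) @ v.float()).to(v.dtype)
    return out, lse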

@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
    }
)
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded,
    EVEN_M: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load
    if EVEN_M:
        o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om
                    + offs_d[None, :]).to(tl.float32)
        do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom
                     + offs_d[None, :]).to(tl.float32)
    else:
        o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om
                    + offs_d[None, :], mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
        do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom
                     + offs_d[None, :], mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
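The Delta written here is the row-wise term D of the softmax backward, ds = p * (dp - D). Since O = P V and dP = dO V^T, the sum over j of P_ij * dP_ij collapses to the sum over d of O_id * dO_id, so D can be computed from o and do alone, without rematerializing P. A one-line PyTorch equivalent for a single (batch, head) slice (editor's note, not part of this commit):

import torch

def reference_delta(o, do):
    # o, do: (seqlen_q, headdim) -> (seqlen_q,)
    return (o.float() * do.float()).sum(dim=-1)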

@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V,
    softmax_scale,
    DO, DQ, DK, DV,
    LSE, D,
    stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k,
    ATOMIC_ADD: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_k[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_k[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    k = tl.load(k_ptrs)
    v = tl.load(v_ptrs)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        q = tl.load(q_ptrs)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        lse_i = tl.load(LSE + offs_m_curr)
        p = tl.exp(qk * softmax_scale - lse_i[:, None])
        # compute dv
        do = tl.load(do_ptrs)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        dp = tl.dot(do, v, trans_b=True)
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not ATOMIC_ADD:
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
            dq += tl.dot(ds, k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
        else:
            # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            tl.atomic_add(dq_ptrs, dq)
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :])
    tl.store(dv_ptrs, dv)
    tl.store(dk_ptrs, dk)
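For reference, the gradients this column-block kernel accumulates, written as dense PyTorch for one (batch, head) slice, with the softmax rematerialized from the saved log-sum-exp rather than stored in the forward pass (editor's illustration, not part of this commit):

import torch

def reference_bwd(do, q, k, v, lse, softmax_scale, causal):
    # q, do: (seqlen_q, d); k, v: (seqlen_k, d); lse: (seqlen_q,) from the forward pass
    qk = (q.float() @ k.float().t()) * softmax_scale
    if causal:
        mask = torch.ones_like(qk, dtype=torch.bool).tril()
        qk = qk.masked_fill(~mask, float('-inf'))
    p = torch.exp(qk - lse[:, None])                 # recomputed softmax
    do_f, v_f = do.float(), v.float()
    dv = p.t() @ do_f                                # tl.dot(p, do, trans_a=True)
    dp = do_f @ v_f.t()                              # tl.dot(do, v, trans_b=True)
    delta = (do_f * (p @ v_f)).sum(dim=-1)           # the Delta from _bwd_preprocess_do_o_dot
    ds = p * (dp - delta[:, None]) * softmax_scale
    dq = ds @ k.float()
    dk = ds.t() @ q.float()
    return dq, dk, dv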

def init_to_zero(name):
    # def fn(nargs):
    #     with torch.no_grad():
    #         nargs[name].zero_()
    # return fn
    return lambda nargs: nargs[name].zero_()
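dq is accumulated into its buffer rather than written once: the non-sequence-parallel path loads, adds, and stores, and the SEQUENCE_PARALLEL path uses tl.atomic_add across seqlen_k blocks, so DQ has to start at zero before every launch. That is what the pre_hook in the configs below arranges. A minimal sketch of the hook's effect, assuming Triton hands it the launch arguments as a name-to-value dict (editor's illustration, not part of this commit):

import torch

dq_buffer = torch.randn(4, 4)
hook = init_to_zero('DQ')      # -> lambda nargs: nargs['DQ'].zero_()
hook({'DQ': dq_buffer})        # zeroes the buffer in place before the kernel runs
assert dq_buffer.abs().sum() == 0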

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Kernel is buggy (gives wrong results) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
                      num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
                      num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
    # reset_to_zero=['DQ']
)
@triton.jit
def _bwd_kernel(
    Q, K, V,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V,
                softmax_scale,
                DO, DQ, DK, DV,
                LSE, D,
                stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k,
                ATOMIC_ADD=False,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V,
            softmax_scale,
            DO, DQ, DK, DV,
            LSE, D,
            stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k,
            ATOMIC_ADD=True,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )

def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d in {16, 32, 64, 128}
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device,
    #                  dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)
    # BLOCK = 128
    # num_warps = 4 if d <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel[grid](
        q, k, v, o, lse, tmp,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, d,
        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        # num_warps=num_warps,
        # num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated
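Note the stride order: q, k, v, and o are laid out as (batch, seqlen, nheads, headdim) while the kernel indexes by batch, head, then sequence position, so the call above passes stride(0), stride(2), stride(1); the head-dim stride is assumed to be 1, which is why the autograd wrappers below force the last dimension to be contiguous. A small sketch of that layout assumption (editor's note, not part of this commit):

import torch

q = torch.empty(2, 1024, 12, 128, dtype=torch.float16)   # (batch, seqlen, nheads, headdim)
assert q.stride(-1) == 1                                  # contiguous head dimension
stride_qb, stride_qh, stride_qm = q.stride(0), q.stride(2), q.stride(1)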

def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None):
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert seqlen_q % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
    assert seqlen_k % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)
    # delta = torch.zeros_like(lse)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded,
        BLOCK_M=128, BLOCK_HEADDIM=d,
    )
    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
    # BLOCK_M = 128
    # BLOCK_N = 64
    # num_warps = 4
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                         batch * nheads)
    _bwd_kernel[grid](
        q, k, v, do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, d,
        # SEQUENCE_PARALLEL=False,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,
        # num_stages=1,
    )
    dq.copy_(dq_accum)

class FlashAttnQKVPackedFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse)
        ctx.causal = causal
        return o
    @staticmethod
    def backward(ctx, do):
        qkv, o, lse = ctx.saved_tensors
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dqkv, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
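A short usage sketch for the packed interface, including the backward pass; the shapes are illustrative and respect the constraints enforced above (head dim in {16, 32, 64, 128}, fp16/bf16, CUDA tensors, seqlen a multiple of 128 for the backward). Editor's illustration, not part of this commit:

import torch

qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16,
                  requires_grad=True)        # (batch, seqlen, 3, nheads, headdim)
out = flash_attn_qkvpacked_func(qkv, True)   # causal=True -> (batch, seqlen, nheads, headdim)
out.sum().backward()                         # gradient w.r.t. the packed tensor lands in qkv.grad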

class FlashAttnKVPackedFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, kv, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen, nheads, headdim)
            kv: (batch, seqlen, 2, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse)
        ctx.causal = causal
        return o
    @staticmethod
    def backward(ctx, do):
        q, kv, o, lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dkv = torch.empty_like(kv)
        _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                             dq, dkv[:, :, 0], dkv[:, :, 1],
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dkv, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply

class FlashAttnFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal=False, softmax_scale=None):
        """
            q, k, v: (batch_size, seqlen, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dk, dv, None, None


flash_attn_func = FlashAttnFunc.apply
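Unlike the CUDA interface in flash_attn_interface.py, this flash_attn_func takes separate, padded q, k, v (no cu_seqlens), allows seqlen_q != seqlen_k, and supports fp16 and bf16. A minimal call sketch with the cross-attention-style shapes the new test below uses (editor's illustration, not part of this commit):

import torch
from flash_attn.flash_attn_triton import flash_attn_func

q = torch.randn(8, 256, 4, 64, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(8, 512, 4, 64, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(8, 512, 4, 64, device='cuda', dtype=torch.float16, requires_grad=True)
out = flash_attn_func(q, k, v, True)   # causal; out: (8, 256, 4, 64)
out.sum().backward()                   # backward requires seqlens that are multiples of 128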
flash_attn/triton/fused_attention.py → flash_attn/flash_attn_triton_og.py

# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking.
# Fixing some dtype casting to make it work for bfloat16
# We fixed a few dtype cast to make it work for bf16
"""
Fused Attention

@@ -78,7 +78,7 @@ def _fwd_kernel(
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(q.dtype)
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new

@@ -178,7 +178,7 @@ def _bwd_kernel(
        p = tl.exp(qk * sm_scale - m[:, None])
        # compute dv
        do = tl.load(do_ptrs)
        dv += tl.dot(p.to(q.dtype), do, trans_a=True)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]

@@ -189,7 +189,7 @@ def _bwd_kernel(
        dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
        # # compute dq
        dq = tl.load(dq_ptrs, eviction_policy="evict_last")
        dq += tl.dot(ds.to(q.dtype), k)
        dq += tl.dot(ds.to(k.dtype), k)
        tl.store(dq_ptrs, dq, eviction_policy="evict_last")
        # # increment pointers
        dq_ptrs += BLOCK_M * stride_qm

@@ -270,95 +270,7 @@ class _attention(torch.autograd.Function):
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None
        return dq.to(q.dtype), dk, dv, None


attention = _attention.apply

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    triton.testing.assert_almost_equal(ref_out, tri_out)
    triton.testing.assert_almost_equal(ref_dv, tri_dv)
    triton.testing.assert_almost_equal(ref_dk, tri_dk)
    triton.testing.assert_almost_equal(ref_dq, tri_dq)

try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 16)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device,
                          requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms


# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)
tests/test_flash_attn.py

@@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)

@@ -849,3 +851,56 @@ def test_flash_attn_multigpu():
    assert 0.99 <= dropout_fraction / dropout_p <= 1.01
    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()

from flash_attn.flash_attn_triton import flash_attn_func


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
    output = flash_attn_func(q, k, v, causal)
    output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
    output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')

    g = torch.randn_like(output)
    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()