gaoqiong / flash-attention · Commits

Commit 50ca2348, authored Oct 23, 2022 by Tri Dao
Parent: 9e92a1f2

    Add Triton implementation for benchmarking

Showing 2 changed files with 443 additions and 0 deletions:

    benchmarks/benchmark_causal.py        +79   -0
    flash_attn/triton/fused_attention.py  +364  -0
benchmarks/benchmark_causal.py  (new file, 0 → 100644)
from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.triton.fused_attention import attention as attention


def attention_pytorch(qkv, dropout_p=0.0, causal=False):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype,
                         device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def attention_triton(q, k, v):
    """
    No dropout, and only causal=True is supported.
    The Triton implementation seems to require q, k, v to be contiguous.
    Arguments:
        q, k, v: (batch_size, nheads, seqlen, head_dim)
    Output:
        output: (batch_size, nheads, seqlen, head_dim)
    """
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    return attention(q, k, v, softmax_scale)


torch.manual_seed(0)
repeats = 30
batch_size = 2
seqlen = 2048
nheads = 12
headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.bfloat16
device = 'cuda'

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                          device=qkv.device)
benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
              cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats,
              desc='FlashAttention')
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats,
              desc='PyTorch Attention')
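
# Editor's sketch (not part of the original commit): a quick numerical cross-check
# of the two implementations benchmarked above. dropout_p is 0.0 here, so both
# paths are deterministic; comparing the max abs difference avoids committing to
# a specific tolerance for bfloat16.
out_pt = attention_pytorch(qkv, dropout_p, causal=causal)
out_flash = rearrange(
    flash_attn_unpadded_qkvpacked_func(rearrange(qkv, 'b s ... -> (b s) ...'),
                                       cu_seqlens, seqlen, dropout_p, causal=causal),
    '(b s) h d -> b s h d', b=batch_size)
print(f'Max abs diff, FlashAttention vs PyTorch: {(out_pt - out_flash).abs().max().item():.2e}')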

q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                       requires_grad=True) for _ in range(3)]
benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')

flash_attn/triton/fused_attention.py  (new file, 0 → 100644)
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking.
# Fixed some dtype casting to make it work for bfloat16.
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats, https://arxiv.org/pdf/2112.05682v2.pdf)
"""

import pytest
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
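    # Editor's note (not in the original): off_k uses the *q* head stride
    # (stride_qh), and off_v reuses the q strides throughout (stride_qh,
    # stride_qm, stride_qk) rather than the k/v strides passed in. This is only
    # correct when q, k, v share an identical contiguous layout, which is
    # presumably why benchmark_causal.py notes that this implementation seems
    # to require contiguous q, k, v.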
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
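        # Editor's note (not in the original): this is the online-softmax
        # recurrence of FlashAttention. With running max m_i and running
        # denominator l_i over the key blocks seen so far,
        #     m_new = max(m_i, m_ij)
        #     l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
        # The partial output acc (normalized by the old l_i) is rescaled below
        # by l_i / l_new * exp(m_i - m_new), and the current block's
        # probabilities p by exp(m_ij - m_new) / l_new, so that after the
        # update acc equals softmax(qk over all blocks so far) @ V.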
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        # BUG: have to store acc_scale and immediately load it back
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(q.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)


@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
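

# Editor's note (not in the original): _bwd_preprocess rescales dO by the saved
# softmax denominator l (the backward kernel can then work with the unnormalized
# probabilities p = exp(qk * sm_scale - m)) and precomputes the row-wise term
# delta = rowsum(o * do_scaled) that the softmax backward subtracts from dP, as
# in dS = P * (dP - delta).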


@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, k, trans_b=True)
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(p.to(q.dtype), do, trans_a=True)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v, trans_b=True)
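            # Editor's note (not in the original): initializing dp to -Di folds
            # the delta subtraction into the dot product, so after the tl.dot
            # above dp holds (do @ v^T) - delta[:, None]; the ds line below is
            # then the softmax backward dS = P * (dP - delta), times sm_scale.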
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
            # compute dq
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
            dq += tl.dot(ds.to(q.dtype), k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device,
                          dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device,
                        dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device,
                        dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=1,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        # NOTE: kernel currently buggy for other values of `num_warps`
        num_warps = 8
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None


attention = _attention.apply
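
# Editor's sketch (not part of the original commit): minimal usage of the fused
# op, assuming a CUDA device. Causal masking is always applied, head_dim must be
# in {16, 32, 64, 128}, and q, k, v should be contiguous (batch, heads, seq, dim)
# tensors of identical shape:
#
#     q, k, v = [torch.randn(2, 4, 1024, 64, dtype=torch.float16, device='cuda',
#                            requires_grad=True) for _ in range(3)]
#     out = attention(q, k, v, q.shape[-1] ** -0.5)
#     out.sum().backward()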


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    triton.testing.assert_almost_equal(ref_out, tri_out)
    triton.testing.assert_almost_equal(ref_dv, tri_dv)
    triton.testing.assert_almost_equal(ref_dk, tri_dk)
    triton.testing.assert_almost_equal(ref_dq, tri_dq)
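
# Editor's note (not in the original): the test above can be run with e.g.
#     pytest flash_attn/triton/fused_attention.py -k test_op
# It needs a CUDA device; triton.testing.assert_almost_equal compares with
# loose tolerances suited to fp16.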


try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 16)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider,
                          dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device,
                          requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms


# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)
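# Editor's note (not in the original): uncommenting the line above runs the
# seq-length sweep defined in `configs` (N_CTX = 1024 ... 32768, bwd mode) via
# triton.testing.perf_report, saving the 'fused-attention-batch4-head48-d64-bwd'
# plot and its data table to the current directory.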