gaoqiong / flash-attention · Commits

Commit b0c0db81, authored Oct 30, 2022 by Tri Dao (parent c422fee3)

Implement FlashAttention in Triton

Showing 4 changed files, with 602 additions and 109 deletions:
    benchmarks/benchmark_causal.py        +13   -16
    flash_attn/flash_attn_triton.py      +529    -0
    flash_attn/flash_attn_triton_og.py     +5   -93
    tests/test_flash_attn.py              +55    -0
benchmarks/benchmark_causal.py
@@ -6,9 +6,11 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
-from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
+from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.triton.fused_attention import attention as attention
+# from flash_attn.triton.fused_attention import attention as attention
+from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
+from flash_attn.flash_attn_triton_og import attention as attention_og
 try:
     from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     return output.to(dtype=qkv.dtype)


-def attention_triton(q, k, v):
-    """
-    No dropout and only support causal=True.
-    Triton implementation seems to require q, k, v being contiguous?
-    Arguments:
-        q, k, v: (batch_size, nheads, seqlen, head_dim)
-    Output:
-        output: (batch_size, nheads, seqlen, head_dim)
-    """
-    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
-    return attention(q, k, v, softmax_scale)
-
-
 def attention_megatron(qkv):
     """
@@ -85,6 +74,10 @@ batch_size = 2
 seqlen = 4096
 nheads = 12
 headdim = 128
+# batch_size = 64
+# seqlen = 512
+# nheads = 8
+# headdim = 128
 dropout_p = 0.0
 causal = True
 dtype = torch.bfloat16
@@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
 benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention')

+benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
+pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)
+
 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
-benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
+benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
 # pytorch_profiler(attention, q, k, v, 1.0, backward=True)

 if scaled_upper_triang_masked_softmax is not None:
     benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
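With these changes, benchmarks/benchmark_causal.py times the new Triton kernel (flash_attn_qkvpacked_func) next to the PyTorch reference, the CUDA flash_attn_unpadded_qkvpacked_func kernel, the original Triton tutorial kernel (attention_og), and, when available, Megatron's fused softmax path. A minimal way to exercise it, assuming the package and Triton are installed and a CUDA GPU is available (as far as this diff shows, the script takes no arguments; sizes are hard-coded at the top):

    python benchmarks/benchmark_causal.py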
flash_attn/flash_attn_triton.py (new file, mode 100644)

This diff is collapsed (529 new lines not shown).
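The new module's contents are collapsed above, but its call signatures can be read off the call sites this commit adds in benchmarks/benchmark_causal.py and tests/test_flash_attn.py. A hedged usage sketch based only on those call sites (q, k, v shapes follow the new test, (batch, seqlen, nheads, headdim); the packed (batch, seqlen, 3, nheads, headdim) layout for qkv is an assumption inferred from the benchmark's rearrange pattern):

    import torch
    from flash_attn.flash_attn_triton import flash_attn_func, flash_attn_qkvpacked_func

    device, dtype = 'cuda', torch.float16
    batch, seqlen, nheads, headdim = 2, 1024, 4, 64

    # Separate q, k, v, as in test_flash_attn_triton: flash_attn_func(q, k, v, causal)
    q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype,
                           requires_grad=True) for _ in range(3)]
    out = flash_attn_func(q, k, v, True)   # causal passed positionally, as in the test
    out.sum().backward()                   # the test also exercises gradients

    # Packed qkv, as in benchmark_causal.py: flash_attn_qkvpacked_func(qkv, causal)
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    out_packed = flash_attn_qkvpacked_func(qkv, True)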
flash_attn/triton/fused_attention.py → flash_attn/flash_attn_triton_og.py (renamed)
 # [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
 # for benchmarking.
-# Fixing some dtype casting to make it work for bfloat16
+# We fixed a few dtype cast to make it work for bf16
 """
 Fused Attention
@@ -78,7 +78,7 @@ def _fwd_kernel(
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(v_ptrs + start_n * stride_vk)
-        p = p.to(q.dtype)
+        p = p.to(v.dtype)
         acc += tl.dot(p, v)
         # update m_i and l_i
         l_i = l_i_new
@@ -178,7 +178,7 @@ def _bwd_kernel(
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
-           dv += tl.dot(p.to(q.dtype), do, trans_a=True)
+           dv += tl.dot(p.to(do.dtype), do, trans_a=True)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
@@ -189,7 +189,7 @@ def _bwd_kernel(
            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
            # # compute dq
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-           dq += tl.dot(ds.to(q.dtype), k)
+           dq += tl.dot(ds.to(k.dtype), k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            # # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
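These kernel edits, together with the dq cast in the next hunk, are the "dtype cast" fixes the new header comment refers to: the fp32 probabilities and gradients are now cast to the dtype of the tensor they are about to be multiplied with (v, do, k) rather than always q.dtype, presumably so that the tl.dot operands line up for bf16 inputs as well as fp16. A rough plain-PyTorch sketch of the same pattern (illustration only, not the Triton kernel; assumes a CUDA GPU like the rest of the repo):

    import torch

    def fwd_block(q, k, v, sm_scale):
        # q, k, v: (seqlen, head_dim), all fp16 or all bf16
        qk = (q.float() @ k.float().T) * sm_scale   # scores kept in fp32
        p = torch.softmax(qk, dim=-1)               # probabilities in fp32
        return p.to(v.dtype) @ v                    # cast to v.dtype before the product

    q, k, v = [torch.randn(128, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3)]
    out = fwd_block(q, k, v, sm_scale=64 ** -0.5)   # out.dtype == torch.bfloat16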
@@ -270,95 +270,7 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
             num_stages=1,
         )
-        return dq, dk, dv, None
+        return dq.to(q.dtype), dk, dv, None


 attention = _attention.apply
-
-
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
-def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
-    torch.manual_seed(20)
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
-    sm_scale = 0.3
-    dout = torch.randn_like(q)
-    # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
-    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-    for z in range(Z):
-        for h in range(H):
-            p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
-    ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
-    # triton implementation
-    tri_out = attention(q, k, v, sm_scale)
-    tri_out.backward(dout)
-    tri_dv, v.grad = v.grad.clone(), None
-    tri_dk, k.grad = k.grad.clone(), None
-    tri_dq, q.grad = q.grad.clone(), None
-    # compare
-    triton.testing.assert_almost_equal(ref_out, tri_out)
-    triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    triton.testing.assert_almost_equal(ref_dq, tri_dq)
-
-
-try:
-    from flash_attn.flash_attn_interface import flash_attn_func
-    HAS_FLASH = True
-except BaseException:
-    HAS_FLASH = False
-
-BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
-# vary seq length for fixed head and batch=4
-configs = [triton.testing.Benchmark(
-    x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 16)],
-    line_arg='provider',
-    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
-    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
-    styles=[('red', '-'), ('blue', '-')],
-    ylabel='ms',
-    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
-    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
-) for mode in ['bwd']]
-
-
-@triton.testing.perf_report(configs)
-def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
-    assert mode in ['fwd', 'bwd']
-    warmup = 25
-    rep = 100
-    if provider == "triton":
-        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-        sm_scale = 1.3
-        fn = lambda: attention(q, k, v, sm_scale)
-        if mode == 'bwd':
-            o = fn()
-            do = torch.randn_like(o)
-            fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
-        return ms
-    if provider == "flash":
-        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
-        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
-        cu_seqlens[1:] = lengths.cumsum(0)
-        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
-        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
-        if mode == 'bwd':
-            o = fn()
-            do = torch.randn_like(o)
-            fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
-        return ms
-
-
-# only works on A100 at the moment
-# bench_flash_attention.run(save_path='.', print_data=True)
tests/test_flash_attn.py
@@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
     # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
     if dropout_mask is not None:
         attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
     output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)
@@ -849,3 +851,56 @@ def test_flash_attn_multigpu():
     assert 0.99 <= dropout_fraction / dropout_p <= 1.01

     assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
+
+
+from flash_attn.flash_attn_triton import flash_attn_func
+
+
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize('d', [64, 128])
+# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
+def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
+    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
+        pytest.skip()  # Reference implementation OOM
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
+    output = flash_attn_func(q, k, v, causal)
+    output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
+    output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
+    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
+    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
+    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
+    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
+    g = torch.randn_like(output)
+    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
+    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
+    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
+    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
+    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
+    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
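To run only this new test locally (a hedged example: it assumes a CUDA GPU, Triton, and a local install of the package; -k and -s are standard pytest options, nothing added by this commit):

    pytest -s -k test_flash_attn_triton tests/test_flash_attn.py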