gaoqiong / flash-attention
Commit 86862cfd, authored Nov 04, 2022 by Tri Dao
Parent: 470010f5

Implement attention bias for Triton version
Showing 2 changed files, with 225 additions and 64 deletions (+225 -64):

flash_attn/flash_attn_triton.py    +180 -48
tests/test_flash_attn.py           +45  -16
flash_attn/flash_attn_triton.py

(This diff is collapsed and not shown here.)
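Although the kernel-side diff is collapsed, the updated tests below import `flash_attn_func` from `flash_attn.flash_attn_triton` and call it positionally as `flash_attn_func(q, k, v, bias, causal)`. The following is a minimal usage sketch based only on that call pattern; argument names, defaults, and supported dtypes in flash_attn_triton.py are not shown in this page, so treat the details as assumptions.

```python
# Sketch only: mirrors the positional call used in the updated tests,
# flash_attn_func(q, k, v, bias, causal). Requires a CUDA GPU and fp16/bf16 inputs.
import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch_size, seqlen, nheads, d = 2, 128, 4, 64
device, dtype = 'cuda', torch.float16

q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)

# Bias broadcastable to (batch_size, nheads, seqlen_q, seqlen_k), kept in fp32
# as in the tests (e.g. a relative-position or ALiBi-style bias).
bias = torch.randn(1, nheads, 1, seqlen, device=device, dtype=torch.float)

out = flash_attn_func(q, k, v, bias, False)  # causal=False
out.sum().backward()
```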
tests/test_flash_attn.py
```diff
@@ -122,7 +122,7 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None
 def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
-                  dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
+                  dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False):
     """
     Arguments:
         q: (batch_size, seqlen_q, nheads, head_dim)
```
```diff
@@ -132,6 +132,7 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
         key_padding_mask: (batch_size, seqlen_k)
         dropout_p: float
         dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        bias: (batch_size, nheads, seqlen_q, seqlen_k)
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
         reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
```
```diff
@@ -150,6 +151,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
         scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
     else:
         scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
+    if bias is not None:
+        scores = (scores + bias).to(dtype=scores.dtype)
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
     if causal:
```
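The two added lines are the semantic core of the change: the bias is simply added to the pre-softmax attention scores. Below is a self-contained sketch of that math in plain PyTorch, using the same einsum layout as `attention_ref` but with no masking or dropout; the helper name and toy sizes are mine, not the repo's.

```python
# Illustrative only: bias enters attention as softmax(q @ k^T / sqrt(d) + bias).
import math
import torch

def attention_with_bias(q, k, v, bias=None):
    # q: (batch, seqlen_q, nheads, head_dim); k, v: (batch, seqlen_k, nheads, head_dim)
    d = q.shape[-1]
    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
    if bias is not None:
        # bias is broadcastable to (batch, nheads, seqlen_q, seqlen_k)
        scores = (scores + bias).to(dtype=scores.dtype)
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum('bhts,bshd->bthd', attn, v)

b, tq, tk, h, d = 2, 5, 7, 3, 16
q, k, v = (torch.randn(b, t, h, d) for t in (tq, tk, tk))
bias = torch.randn(1, h, 1, tk)  # broadcasts over batch and query positions
out = attention_with_bias(q, k, v, bias)
print(out.shape)  # torch.Size([2, 5, 3, 16])
```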
```diff
@@ -863,11 +866,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [48])
+# @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
-def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+# @pytest.mark.parametrize('bias_shape', (['1h1k']))
+def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
```
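The new `bias_shape` parameter names which dimensions of the score-shaped bias are singletons: each code position stands for (batch, nheads, seqlen_q, seqlen_k), with '1' marking a broadcast dimension. The shapes themselves come from the bias construction in the next hunk; the helper below is a small sketch (its name is mine) showing that every variant broadcasts against the full score tensor.

```python
# Sketch: map the test's bias_shape codes to broadcastable bias tensors.
import torch

def make_bias(bias_shape, batch_size, nheads, seqlen_q, seqlen_k):
    shapes = {
        '1h1k': (1, nheads, 1, seqlen_k),
        '1hqk': (1, nheads, seqlen_q, seqlen_k),
        'b11k': (batch_size, 1, 1, seqlen_k),
        'b1qk': (batch_size, 1, seqlen_q, seqlen_k),
    }
    return None if bias_shape is None else torch.randn(*shapes[bias_shape], dtype=torch.float)

scores = torch.zeros(2, 4, 113, 203)  # (batch, nheads, seqlen_q, seqlen_k)
for code in [None, '1h1k', '1hqk', 'b11k', 'b1qk']:
    bias = make_bias(code, 2, 4, 113, 203)
    out = scores if bias is None else scores + bias  # broadcasting handles the singleton dims
    assert out.shape == scores.shape
```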
```diff
@@ -877,12 +882,23 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output = flash_attn_func(q, k, v, causal)
-    output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
-    output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
+    output = flash_attn_func(q, k, v, bias, causal)
+    output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal)
+    output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False, reorder_ops=True)
     print(f'Output max diff: {(output - output_ref).abs().max().item()}')
     print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
```
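The printed diagnostics compare the Triton output against an fp32 reference (`output_ref`) and, alongside it, a pure-PyTorch fp16/bf16 baseline (`output_pt`) against the same reference. The test's actual assertions sit outside the displayed hunk, so the check below is only an illustrative sketch of that kind of relative comparison, with invented names and a made-up tolerance factor, not the test's elided code.

```python
# Illustrative sketch only: judge a low-precision kernel by its error relative to a
# plain low-precision PyTorch baseline, rather than against a fixed absolute tolerance.
import torch

def relative_error_ok(output, output_pt, output_ref, factor=2.0):
    kernel_err = (output - output_ref).abs().max().item()
    baseline_err = (output_pt - output_ref).abs().max().item()
    return kernel_err <= factor * baseline_err + 1e-8  # small slack for an exact-zero baseline

# Toy stand-ins for flash_attn_func / attention_ref outputs.
ref = torch.randn(2, 128, 4, 64)
pt = ref.half().float()            # baseline: fp16 round-trip of the reference
kernel = ref + 0.5 * (pt - ref)    # pretend the kernel lands even closer to the reference
print(relative_error_ok(kernel, pt, ref))  # True
```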
```diff
@@ -919,13 +935,14 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
-# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-@pytest.mark.parametrize('d', [64, 128])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [96])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
-def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
+@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
```
```diff
@@ -935,19 +952,31 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+    if bias_shape == '1h1k':
+        bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == '1hqk':
+        bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b11k':
+        bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
+    elif bias_shape == 'b1qk':
+        bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
+    else:
+        bias = None
     q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
-    output_0 = flash_attn_func(q, k, v, causal)
+    output_0 = flash_attn_func(q, k, v, bias, causal)
     g = torch.randn_like(output_0)
     dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
     # The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic
     deterministic_dq = False
-    equal_fn = (torch.equal if deterministic_dq
-                else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
+    # Run 10000 times and check that the results don't change
     for i in range(10000):
-        output = flash_attn_func(q, k, v, causal)
+        output = flash_attn_func(q, k, v, None, causal)
         output_equal = torch.equal(output, output_0)
         if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
```
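The new `dq_atol` line replaces the hard-coded 1e-3/1e-5 tolerances: adding and subtracting 0.3 is mathematically a no-op, so any deviation it leaves on `dq_0` measures the rounding noise that a single arithmetic pass introduces at that dtype and magnitude, which is then used as the `allclose` atol. A standalone sketch of the trick, with toy tensors of my own choosing:

```python
# Standalone sketch of the dq_atol idea: (x + 0.3 - 0.3) should equal x exactly, so the
# largest deviation estimates the dtype's round-off noise floor for tensors like x.
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    x = torch.randn(1024, dtype=dtype)
    noise_floor = ((x + 0.3 - 0.3) - x).abs().max().item()
    print(dtype, noise_floor)
    # Differences at or below this level should be treated as equal:
    assert torch.allclose(x + 0.3 - 0.3, x, atol=noise_floor, rtol=0)
```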