gaoqiong / flash-attention
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "96d7c3e99fd9e6c5a70d73051576b89478ead098"
Commit 0c01568d authored Oct 04, 2022 by Tri Dao

Only run backward test for d=128 on A100

parent 8166063a
Showing 1 changed file with 10 additions and 8 deletions

tests/test_flash_attn.py  (+10, -8)
@@ -12,6 +12,7 @@ from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
+is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
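The hunk above adds an is_sm80 flag alongside the existing is_sm75 one. For context, a minimal sketch (not from this commit) of how such compute-capability flags gate the backward test, assuming a CUDA-capable PyTorch build; the helper name supports_backward_for is illustrative only:

    import torch

    # Compute-capability flags, mirroring the pattern in the hunk above.
    is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)  # Turing (e.g. T4, RTX 2080)
    is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)  # Ampere A100

    def supports_backward_for(head_dim: int) -> bool:
        # Per the commit message: backward for head dimension 128 is only
        # exercised on A100 (sm80); smaller head dimensions run everywhere.
        return is_sm80 or head_dim < 128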
@@ -331,6 +332,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [128, 64, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
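The stacked @pytest.mark.parametrize decorators above generate the full Cartesian product of (dtype, causal, d, seqlen) combinations. A minimal sketch (toy test, not from the repo) showing how that multiplication works:

    import pytest
    import torch

    # Two stacked parametrize decorators: pytest runs this test once per
    # (dtype, d) pair, i.e. 2 x 3 = 6 invocations. The real tests above do
    # the same with four decorators, so each test expands into a large grid.
    @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
    @pytest.mark.parametrize('d', [16, 32, 64])
    def test_parametrize_grid(dtype, d):
        x = torch.zeros(d, dtype=dtype)
        assert x.shape == (d,)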
@@ -385,7 +387,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
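The gated block uses torch.autograd.grad with a random cotangent g rather than calling .backward(), so the gradient of the packed input comes back as a value. A minimal sketch of that pattern (the op below is a stand-in, not the flash-attention kernel):

    import torch

    qkv = torch.randn(8, 3, 4, 16, requires_grad=True)  # illustrative packed QKV tensor
    output = qkv.sum(dim=1).relu()                       # stand-in for the attention forward
    g = torch.randn_like(output)                         # random cotangent, as in the test
    dqkv, = torch.autograd.grad(output, qkv, g)          # gradient w.r.t. the packed input
    assert dqkv.shape == qkv.shape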
@@ -411,7 +413,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
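The assertion above is a relative accuracy check: the kernel's worst-case gradient error must stay within twice the error of a plain PyTorch low-precision implementation measured against the same reference. A minimal numeric sketch of that criterion with toy tensors (illustrative only):

    import torch

    ref = torch.randn(1024, dtype=torch.float64)   # higher-precision reference
    pt = ref.to(torch.float16).to(torch.float64)   # plain PyTorch low-precision baseline
    kernel = pt + 1e-5 * torch.randn_like(pt)      # stand-in for the fused-kernel output

    kernel_err = (kernel - ref).abs().max().item()
    baseline_err = (pt - ref).abs().max().item()
    # Same shape of bound as the test: kernel error <= 2x the baseline's error.
    assert kernel_err <= 2 * baseline_err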
@@ -476,7 +478,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -501,7 +503,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
         # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
@@ -568,7 +570,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -594,7 +596,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -640,7 +642,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
     )
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         g = torch.randn_like(output_unpad_0)
         dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                                   (q_unpad, k_unpad, v_unpad), g)
@@ -659,7 +661,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     # assert torch.equal(sm_lse, sm_lse_0)
     assert torch.equal(S_dmask_converted, S_dmask_converted_0)
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                             (q_unpad, k_unpad, v_unpad), g)
         assert torch.equal(dq_unpad, dq_unpad_0)
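The race-condition test gated here runs the same forward and backward twice on identical inputs and requires bitwise-identical results via torch.equal (not allclose). A minimal sketch of that determinism check on a stand-in op:

    import torch

    torch.manual_seed(0)
    x = torch.randn(64, 64, requires_grad=True)
    g = torch.randn(64, 64)

    # Run the same computation twice; a racy kernel could produce
    # slightly different results between runs.
    out_0 = x @ x.t()
    dx_0, = torch.autograd.grad(out_0, x, g)
    out_1 = x @ x.t()
    dx_1, = torch.autograd.grad(out_1, x, g)

    assert torch.equal(out_0, out_1)  # forward matches bit-for-bit
    assert torch.equal(dx_0, dx_1)    # backward matches bit-for-bit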