gaoqiong / flash-attention · Commits

Commit 3d41db3e, authored Jul 10, 2024 by Tri Dao

Only test backward if there's no softcapping

Parent: 908511b2
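For context: `softcap` applies a tanh soft-cap to the pre-softmax attention logits, bounding them to (-softcap, softcap). The CUDA backward presumably did not yet support softcapping at the time, so this commit gates the gradient checks on `softcap == 0.0`. Below is a minimal sketch of the capping transform, assuming the tanh formulation used by FlashAttention's `softcap` option (the helper name is ours, not from the repo):

```python
import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    """Tanh soft-capping: scores -> softcap * tanh(scores / softcap).

    Squashes logits into (-softcap, softcap); a no-op when softcap == 0.
    """
    if softcap > 0.0:
        scores = torch.tanh(scores / softcap) * softcap
    return scores

# Example: with softcap=30.0 a logit of 100 is capped to roughly 29.92,
# while small logits pass through almost unchanged.
print(softcap_logits(torch.tensor([1.0, 100.0]), 30.0))
```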
Showing 1 changed file with 4 additions and 4 deletions.
tests/test_flash_attn.py (+4, -4)
```diff
@@ -1051,7 +1051,7 @@ def test_flash_attn_output(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq,
```
```diff
@@ -1107,7 +1107,7 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
```
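The pattern in these assertions: FlashAttention's gradients are compared against a high-precision reference (`*_ref`), and the allowed error is a small multiple of the error that a pure-PyTorch implementation in the same dtype (`*_pt`) makes against that reference, 2x here and 3x in the varlen test below. The same check in isolation, as a hypothetical helper (name ours):

```python
import torch

def within_ref_tolerance(x: torch.Tensor, x_ref: torch.Tensor,
                         x_pt: torch.Tensor, factor: float = 2.0) -> bool:
    """Max abs error of x vs. the reference must not exceed `factor` times
    the max abs error of the same-dtype PyTorch baseline vs. the reference."""
    return (x - x_ref).abs().max().item() <= factor * (x_pt - x_ref).abs().max().item()
```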
```diff
@@ -1365,7 +1365,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq_unpad,
```
```diff
@@ -1424,7 +1424,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
```
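As a design note, the eligibility guard this commit extends now appears four times in the file; it could be stated once. A hedged sketch of such a refactor (the helper name is hypothetical, and `MAX_HEADDIM_SM8x = 192` is an assumption, not read from this diff):

```python
MAX_HEADDIM_SM8x = 192  # assumed value; the repo defines its own constant

def backward_supported(d: int, dropout_p: float, is_sm80: bool,
                       is_sm90: bool, softcap: float) -> bool:
    """The guard from this commit, factored out: backward is only tested
    when the head dim / arch combination is supported and softcap is off."""
    headdim_ok = d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)
    return (headdim_ok or is_sm80 or is_sm90) and softcap == 0.0
```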