Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
d0787acc
Commit
d0787acc
authored
Jul 10, 2024
by
Tri Dao
Browse files
Relax dropout_fraction test
parent
dca6d89d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
1 deletion
+1
-1
tests/test_flash_attn.py
tests/test_flash_attn.py
+1
-1
No files found.
tests/test_flash_attn.py
View file @
d0787acc
...
@@ -1430,7 +1430,7 @@ def test_flash_attn_varlen_output(
...
@@ -1430,7 +1430,7 @@ def test_flash_attn_varlen_output(
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if
not
alibi
:
if
not
alibi
:
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.0
25
)
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.0
4
)
if
((
d
<=
MAX_HEADDIM_SM8x
or
(
d
>
224
and
dropout_p
==
0
))
or
(
is_sm80
or
is_sm90
))
and
softcap
==
0.0
:
if
((
d
<=
MAX_HEADDIM_SM8x
or
(
d
>
224
and
dropout_p
==
0
))
or
(
is_sm80
or
is_sm90
))
and
softcap
==
0.0
:
assert
(
dq
-
dq_ref
).
abs
().
max
().
item
()
<=
3
*
(
dq_pt
-
dq_ref
).
abs
().
max
().
item
()
assert
(
dq
-
dq_ref
).
abs
().
max
().
item
()
<=
3
*
(
dq_pt
-
dq_ref
).
abs
().
max
().
item
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment