gaoqiong / flash-attention

Commit 6df7e0a0 (unverified)
Authored Jul 04, 2024 by muoshuosha, committed by GitHub Jul 03, 2024

Fix the varlen deterministic test (#1023)

Co-authored-by: moshuosha <moshuosha@qq.com>

Parent: 9486635c
Showing 1 changed file with 4 additions and 4 deletions:

tests/test_flash_attn.py (+4, -4)
tests/test_flash_attn.py @ 6df7e0a0

@@ -2459,9 +2459,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
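For context, the old assertions compared each gradient tensor to itself (e.g. torch.equal(dv, dv)), which passes trivially; the fix computes reference gradients dq0, dk0, dv0 once and requires every repeated backward pass to reproduce them exactly. Below is a minimal, self-contained sketch of that determinism-check pattern using only stock PyTorch; attn_fn and the tensor shapes are illustrative stand-ins for the flash-attention varlen kernel, not code from the repository.

# Sketch of the determinism check, assuming plain PyTorch only.
import torch
import torch.nn.functional as F

def attn_fn(q, k, v):
    # Illustrative stand-in: any differentiable attention-like op demonstrates the pattern.
    return F.scaled_dot_product_attention(q, k, v)

q, k, v = (torch.randn(1, 4, 128, 64, requires_grad=True) for _ in range(3))
out = attn_fn(q, k, v)
g = torch.randn_like(out)

# Reference gradients, computed once.
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)

# Re-run the backward pass repeatedly; a deterministic backward must reproduce
# the reference gradients bit-for-bit, hence torch.equal rather than allclose.
for _ in range(50):
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
    assert torch.equal(dv, dv0)
    assert torch.equal(dk, dk0)
    assert torch.equal(dq, dq0)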