gaoqiong / flash-attention · Commits

Commit f4628b43 (unverified), authored Jul 09, 2024 by Phil Wang, committed via GitHub Jul 09, 2024

missing commas and backwards return arguments (#1032)

* missing commas
* another fix

parent 8f873cc6
Showing 1 changed file with 6 additions and 6 deletions

flash_attn/flash_attn_interface.py (+6, -6)
@@ -286,7 +286,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
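The extra None on the added line is the "backwards return arguments" half of the title: torch.autograd.Function.backward must return exactly one gradient per argument that forward received after ctx, with None for non-differentiable arguments such as flags and scales, and here the return was one slot short (presumably after forward gained a new argument). A minimal sketch of that contract, with made-up names rather than flash-attention's own code:

import torch

class Scale(torch.autograd.Function):
    # forward takes two inputs after ctx: a tensor x and a float factor...
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    # ...so backward must return exactly two values: a gradient for x,
    # and None for the non-tensor factor. One value too few or too many
    # is a runtime error.
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])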
@@ -511,7 +511,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
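A side note on the unchanged [..., : dout.shape[-1]] context lines, since they recur in every hunk: per the comment "We could have padded the head dimension", the kernels may pad the head dimension internally, so each gradient is sliced back to the incoming gradient's head size before being returned. An illustration of that ellipsis slice, with invented shapes:

import torch

dout = torch.randn(2, 128, 8, 64)  # incoming gradient, head dim 64
dq = torch.randn(2, 128, 8, 72)    # kernel result, head dim padded (e.g. to 72)
dq = dq[..., : dout.shape[-1]]     # trim only the last dimension back to 64
print(dq.shape)                    # torch.Size([2, 128, 8, 64])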
@@ -572,7 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
-            ctx.softcap
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
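This hunk is the "missing commas" half of the title: without the trailing comma, ctx.softcap and ctx.alibi_slopes sit side by side inside the call's parentheses, which is a SyntaxError, so the backward pass could not even be parsed. A stripped-down reproduction using compile (the call name f is invented):

broken = """
f(
    softcap
    alibi_slopes,
)
"""
try:
    compile(broken, "<demo>", "exec")
except SyntaxError as e:
    # On recent CPython: "invalid syntax. Perhaps you forgot a comma?"
    print("SyntaxError:", e.msg)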
@@ -580,7 +580,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -659,7 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
-            ctx.softcap
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -667,7 +667,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(
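Why the return-count half of the fix matters in practice: with one None missing, autograd raises the first time any gradient flows through the function, which is the failure mode this commit removes. A toy reproduction of the error (again not flash-attention's own code):

import torch

class Bad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, flag):  # two inputs after ctx...
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return (grad_out * 2,)  # ...but only one gradient returned

x = torch.ones(2, requires_grad=True)
try:
    Bad.apply(x, True).sum().backward()
except RuntimeError as e:
    # e.g. "function ...Backward returned an incorrect number of gradients
    # (expected 2, got 1)"
    print(e)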