gaoqiong / flash-attention · Commits

Commit 81e01efd, authored Jul 10, 2024 by Tri Dao
More typo fixes
parent 72e27c63
Showing 2 changed files with 16 additions and 2 deletions (+16 −2)
flash_attn/flash_attn_interface.py  +12 −2
tests/test_flash_attn.py  +4 −0
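Note: the "typos" fixed here are missed call sites — the previously introduced softcap argument is threaded through _flash_attn_varlen_forward, the packed-KV autograd Functions, and the reference implementations in the tests. For orientation, a minimal usage sketch of softcapping from the public API; the call below assumes the flash_attn_qkvpacked_func signature of this era of the library and is illustrative, not taken from the diff:

import torch
from flash_attn import flash_attn_qkvpacked_func

# qkv packed as (batch, seqlen, 3, nheads, headdim), fp16 on GPU.
qkv = torch.randn(2, 1024, 3, 8, 64, dtype=torch.float16, device="cuda")
out = flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    causal=True,
    softcap=30.0,  # 0.0 disables softcapping; nonzero bounds attention scores
)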
flash_attn/flash_attn_interface.py
@@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         return_softmax,
         block_table,
@@ -102,6 +103,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -300,6 +302,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -318,6 +321,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -328,6 +332,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -355,12 +360,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -373,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -387,6 +394,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -395,6 +403,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -419,13 +428,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
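Note on the extra trailing None in each backward return: torch.autograd.Function.backward must return exactly one gradient (or None) per positional argument of forward, so adding softcap to forward's parameter list requires one more None for that non-differentiable input. A self-contained toy Function illustrating the contract (not the flash-attn kernel):

import torch

class ScaledSquare(torch.autograd.Function):
    # forward takes (ctx, x, scale); backward must therefore return two
    # values: the gradient for x and None for the non-tensor scale.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return scale * x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * ctx.scale * x, None  # None lines up with scale

x = torch.randn(4, requires_grad=True)
ScaledSquare.apply(x, 3.0).sum().backward()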
tests/test_flash_attn.py
@@ -303,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -318,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
@@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
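Note: these reference functions just forward softcap to the underlying attention_ref, which applies it to the pre-softmax scores. In flash-attn this is tanh softcapping (as used by Gemma-2): scores are squashed into (-softcap, softcap) before the softmax. A hedged standalone sketch of that transform; apply_softcap is an illustrative helper, not the test file's exact code:

import torch

def apply_softcap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Tanh softcapping: divide, tanh, scale back; bounds every score to
    # (-softcap, softcap). softcap == 0.0 means the cap is disabled.
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)

# Toy shapes: (batch, nheads, seqlen, headdim)
q = torch.randn(1, 4, 16, 8)
k = torch.randn(1, 4, 16, 8)
scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / 8 ** 0.5
attn = torch.softmax(apply_softcap(scores, softcap=30.0), dim=-1)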