gaoqiong / flash-attention · Commit b16c2794

Expose out in python API (#2)

Unverified commit b16c2794, authored May 22, 2024 by Antoni Baum, committed by GitHub on May 22, 2024.
Parent: eee8e47c

Changes: 1 changed file with 40 additions and 4 deletions.
File: vllm_flash_attn/flash_attn_interface.py (+40 / -4)
@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
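The maybe_contiguous helper visible in the context lines above only copies a tensor when its innermost stride is not 1. A minimal standalone sketch of that idiom (tensor shapes are arbitrary, chosen only for illustration):

import torch

# Mirror of the maybe_contiguous lambda in the hunk above: copy only when the
# last dimension is not already contiguous in memory.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

q = torch.randn(2, 8, 64)
qt = q.transpose(-1, -2)                     # innermost stride is no longer 1
assert maybe_contiguous(q) is q              # already contiguous: returned as-is
assert maybe_contiguous(qt).stride(-1) == 1  # non-contiguous: copied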
@@ -52,7 +52,7 @@ def _flash_attn_forward(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,6 +80,8 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -87,7 +89,7 @@ def _flash_attn_varlen_forward(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ class FlashAttnFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ class FlashAttnFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
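With this change, flash_attn_qkvpacked_func accepts a caller-provided output buffer through the new keyword-only out argument. A hedged usage sketch: the packed (batch, seqlen, 3, nheads, headdim) layout and fp16 inputs follow the existing flash-attn documentation, while the expected shape and dtype of out are assumptions based on the diff, which simply forwards the buffer to the kernel.

import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_qkvpacked_func

# Assumed layout: qkv is (batch, seqlen, 3, nheads, headdim); the attention
# output is (batch, seqlen, nheads, headdim). Requires a CUDA device.
qkv = torch.randn(2, 512, 3, 16, 64, device="cuda", dtype=torch.float16)
out_buf = torch.empty(2, 512, 16, 64, device="cuda", dtype=torch.float16)

# out is keyword-only in the new signature; the kernel can write into the
# caller-owned buffer instead of allocating a fresh output tensor.
result = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True, out=out_buf)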
@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
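flash_attn_func gains the same keyword-only out argument. A sketch of reusing one preallocated buffer across calls, again assuming the documented (batch, seqlen, nheads, headdim) layout; whether the returned tensor aliases the supplied buffer is not asserted here.

import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Reuse a single output buffer across iterations to avoid a per-call allocation.
out_buf = torch.empty_like(q)
attn = flash_attn_func(q, k, v, dropout_p=0.0, causal=True, out=out_buf)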
@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )
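flash_attn_varlen_func likewise threads out through to the varlen kernel. A sketch with two packed sequences; the cu_seqlens/max_seqlen conventions come from the existing API, and the shape of out (matching q) is an assumption.

import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

# Two sequences of lengths 3 and 5 packed into one (total_tokens, nheads, headdim)
# tensor, with cumulative sequence lengths marking the boundaries.
nheads, headdim = 8, 64
q = torch.randn(8, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

out_buf = torch.empty_like(q)  # caller-owned buffer for the attention output
attn = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5, max_seqlen_k=5,
    causal=True,
    out=out_buf,  # keyword-only per the new signature
)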
@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
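In flash_attn_with_kvcache, the positional None previously passed for the output tensor is replaced by the new out argument, so decode loops can keep writing into one persistent buffer. A sketch under the usual kv-cache calling convention (cache shapes and dtypes here are illustrative assumptions):

import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, nheads, headdim = 4, 16, 128
k_cache = torch.zeros(batch, 2048, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, device="cuda", dtype=torch.int32)

# One new token per sequence (a typical decode step).
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Persistent per-step output buffer; out is the new keyword-only argument,
# forwarded to the kernel where None used to be passed.
out_buf = torch.empty_like(q)
attn = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    cache_seqlens=cache_seqlens, causal=True, out=out_buf,
)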