gaoqiong / flash-attention · Commits · d562aa63

Commit d562aa63 (unverified)
Sync with FA v2.6.0 to support soft capping (#13)

Authored Jul 31, 2024 by Woosuk Kwon; committed by GitHub on Jul 31, 2024.
Parent: 12375706

Showing 1 changed file with 79 additions and 13 deletions (+79, -13):
vllm_flash_attn/flash_attn_interface.py
@@ -46,7 +46,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -60,6 +60,7 @@ def _flash_attn_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
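Note: the hunks above thread the new `softcap` argument from `_flash_attn_forward` straight into the fused kernel call (`flash_attn_cuda.fwd`). In upstream FlashAttention v2.6.0, soft capping squashes the attention logits with a tanh before the softmax (as used by Gemma-2-style models). The sketch below is a plain-PyTorch reference for that math only, not the kernel code; the function name and einsum layout are illustrative assumptions.

```python
import torch

def softcapped_attention_reference(q, k, v, softcap=30.0, softmax_scale=None):
    """Plain-PyTorch reference for tanh soft capping; not the fused CUDA kernel."""
    # q, k, v: (batch, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bshd,bthd->bhst", q * softmax_scale, k)
    if softcap > 0.0:
        # Every logit is squashed into (-softcap, softcap), bounding softmax sharpness.
        scores = softcap * torch.tanh(scores / softcap)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", probs, v)
```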
@@ -78,6 +79,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     return_softmax,
     block_table,
@@ -103,6 +105,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -125,13 +128,19 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -147,6 +156,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -172,13 +182,19 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -199,6 +215,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -217,6 +234,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -233,6 +251,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -242,6 +261,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -265,12 +285,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
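Note: a pattern repeated in every `FlashAttn*Func` class in this diff: `torch.autograd.Function.backward` must return exactly one gradient (or `None` for non-differentiable inputs) per argument of `forward`, so adding `softcap` to each `forward` forces one extra trailing `None` in each `return`. A minimal, self-contained sketch of that rule, with a hypothetical `scale` argument standing in for `softcap`:

```python
import torch

class ScaledMul(torch.autograd.Function):
    """Toy example of the pattern: one return value per forward() argument."""

    @staticmethod
    def forward(ctx, x, scale):  # `scale` plays the role of softcap here
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient for x, None for the non-tensor hyperparameter `scale`.
        return grad_out * ctx.scale, None
```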
@@ -284,6 +305,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -304,6 +326,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -315,6 +338,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -342,12 +366,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -360,6 +385,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -375,6 +401,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -384,6 +411,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -408,13 +436,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -431,6 +460,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -450,6 +480,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -464,6 +495,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -492,13 +524,14 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
@@ -512,6 +545,7 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -527,6 +561,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -536,6 +571,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -558,6 +594,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -565,7 +602,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -583,6 +620,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -603,6 +641,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
@@ -617,6 +656,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -643,6 +683,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -650,7 +691,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(
@@ -659,6 +700,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # <=0.0 means deactivate
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -682,6 +724,7 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -704,6 +747,7 @@ def flash_attn_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -718,6 +762,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -757,6 +802,7 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -781,6 +827,7 @@ def flash_attn_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -796,6 +843,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -858,6 +906,7 @@ def flash_attn_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
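Note: with `softcap` added to `flash_attn_func`, enabling soft capping is a one-keyword change for callers. A minimal usage sketch, assuming the function is imported from this module, a CUDA device, and fp16 tensors shaped (batch, seqlen, nheads, headdim) as in the docstring; the cap value 30.0 is arbitrary and only for illustration:

```python
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_func

q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)

out = flash_attn_func(
    q, k, v,
    causal=True,
    softcap=30.0,          # > 0 activates soft capping; 0.0 (the default) disables it
    window_size=(-1, -1),  # full attention, no sliding window
)
```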
@@ -873,6 +922,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -899,6 +949,7 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -908,7 +959,7 @@ def flash_attn_varlen_qkvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -923,6 +974,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -941,6 +993,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -986,6 +1039,7 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -996,7 +1050,7 @@ def flash_attn_varlen_kvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1014,6 +1068,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1033,6 +1088,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -1077,6 +1133,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -1087,7 +1144,7 @@ def flash_attn_varlen_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1106,6 +1163,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
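Note: the varlen path gains the same keyword. The sketch below assumes the usual varlen calling convention (tokens of all sequences packed along the first dimension, with `cu_seqlens_q`/`cu_seqlens_k` offsets and `max_seqlen_q`/`max_seqlen_k` lengths); those parameter names are not visible in this diff and are an assumption about the surrounding signature:

```python
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

# Two sequences of lengths 3 and 5 packed into one (total=8, nheads, headdim) tensor.
q = torch.randn(8, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(8, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(8, 8, 128, device="cuda", dtype=torch.float16)
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    causal=True,
    softcap=30.0,  # soft capping now also applies on the variable-length path
)
```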
@@ -1128,9 +1186,11 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
     *,
     out=None,
 ):
@@ -1200,6 +1260,7 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
@@ -1211,9 +1272,13 @@ def flash_attn_with_kvcache(
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1244,7 +1309,8 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
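Note: for the kv-cache decode path, the commit adds both `softcap` and `return_softmax_lse`, and the final hunk changes the return value to `(out, softmax_lse)` when the latter is set. A usage sketch follows; the `q` and `cache_seqlens` parameters and their shapes are not visible in this diff and are assumptions based on the function's docstring:

```python
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
cache_seqlens = torch.tensor([17, 230], device="cuda", dtype=torch.int32)  # valid tokens per sequence

out, softmax_lse = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    softcap=30.0,             # new in this sync: soft capping in the decode path
    return_softmax_lse=True,  # new in this sync: also return the per-row logsumexp
)
```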