Unverified Commit 4e8d6006 authored by Jianwei Dong, committed by GitHub

Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989)
parent 6df7e0a0
@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
 
     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
\ No newline at end of file
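
A minimal usage sketch of the new flag (shapes and values below are hypothetical, and assume a CUDA build of flash-attn that includes this change):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_q, nheads, headdim = 2, 1, 8, 64  # single-token decode step
max_seqlen_k = 256
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch, max_seqlen_k, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch, max_seqlen_k, nheads, headdim, device=device, dtype=dtype)
# Number of valid cached tokens per batch element.
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device=device)

# With return_softmax_lse=True, the function returns a (out, softmax_lse)
# tuple instead of just out.
out, softmax_lse = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, return_softmax_lse=True
)
assert out.shape == (batch, seqlen_q, nheads, headdim)
assert softmax_lse.shape == (batch, nheads, seqlen_q)

# Per the docstring, softmax_lse[b, h, i] is logsumexp_j(q_i . k_j * scale),
# the log of the softmax normalization factor. A naive reference (up to
# fp16 kernel numerics), using the default scale of 1/sqrt(headdim):
scale = headdim ** -0.5
scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k_cache[:, :128].float()) * scale
lse_ref = torch.logsumexp(scores, dim=-1)  # ~= softmax_lse
```

Exposing the LSE is useful, for example, when attention outputs computed over separate key/value chunks need to be renormalized and merged.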