Unverified Commit 4e8d6006 authored by Jianwei Dong, committed by GitHub

Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989)
parent 6df7e0a0
@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@ def flash_attn_with_kvcache(
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@ def flash_attn_with_kvcache(
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
\ No newline at end of file
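
A minimal usage sketch of the new flag. It assumes the rest of the flash_attn_with_kvcache call (q, k_cache, v_cache, cache_seqlens, causal) behaves as described in the docstring above and that a CUDA device with fp16 tensors is available; the shapes and values are illustrative, not taken from the commit.

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, seqlen_q, seqlen_cache, nheads, headdim = 2, 1, 128, 8, 64
    device, dtype = "cuda", torch.float16

    q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
    # Number of valid entries already in the cache for each batch element.
    cache_seqlens = torch.full((batch,), 64, dtype=torch.int32, device=device)

    # Default behavior is unchanged: only the attention output is returned.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)

    # With return_softmax_lse=True the call returns an (out, softmax_lse) tuple,
    # where softmax_lse has shape (batch_size, nheads, seqlen) per the docstring.
    out, softmax_lse = flash_attn_with_kvcache(
        q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, return_softmax_lse=True
    )
    print(out.shape, softmax_lse.shape)  # torch.Size([2, 1, 8, 64]) torch.Size([2, 8, 1])

One common reason to want softmax_lse is combining attention computed over separate key/value chunks: since it is the log of each row's softmax normalization factor, partial outputs can be rescaled and merged consistently.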