Unverified commit b16c2794 authored by Antoni Baum, committed by GitHub

Expose out in python API (#2)

parent eee8e47c
@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -52,7 +52,7 @@ def _flash_attn_forward(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
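For context, the maybe_contiguous helper visible in these hunks only copies q, k, or v when the last dimension is not unit-stride, so well-laid-out inputs pass through untouched. A standalone check of that behavior (not part of the patch):

import torch

maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

x = torch.randn(4, 8)
assert maybe_contiguous(x) is x        # last dim already unit-stride: no copy
y = x.t()                              # transpose makes stride(-1) == 8
assert maybe_contiguous(y) is not y    # copied into a contiguous layout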
@@ -80,6 +80,8 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -87,7 +89,7 @@ def _flash_attn_varlen_forward(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ class FlashAttnFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ class FlashAttnFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
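With flash_attn_func now accepting the keyword, callers can hand the kernel a preallocated output buffer instead of letting each call allocate a fresh one. A minimal sketch of the intended use, assuming the standard (batch, seqlen, nheads, headdim) layout on a CUDA device; the shapes and names below are illustrative, not part of the patch:

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Preallocate once and reuse across calls; `out` is keyword-only,
# matching the `*, out=None` signature added above.
out = torch.empty_like(q)
result = flash_attn_func(q, k, v, causal=True, out=out)

Because out is forwarded down to the CUDA kernel in place of the previous hard-coded None, the attention result lands in the supplied buffer rather than in a new allocation.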
@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )
@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
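The kv-cache path gets the same treatment: flash_attn_with_kvcache forwards out to the fwd_kvcache kernel call in place of the old None. A hedged decode-step sketch, again with illustrative shapes (one new query token per sequence, caches preallocated to a maximum length):

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_cache_len = 2, 16, 64, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")

# One persistent output buffer reused on every decode step, instead of
# a fresh allocation per generated token.
out = torch.empty_like(q)
attn_out = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, out=out
)

Since an inference loop calls this once per generated token, reusing a single buffer avoids a per-step allocation, which is presumably the point of exposing out in this API.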