Unverified Commit b16c2794 authored by Antoni Baum, committed by GitHub

Expose out in python API (#2)

parent eee8e47c
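For context, the change below threads a new keyword-only `out=` argument through the public wrappers down to the CUDA kernels, so callers can supply a preallocated output buffer instead of having the kernel allocate one per call. A minimal usage sketch against the patched `flash_attn_func` follows; the import path, device, and tensor sizes are assumptions for illustration, not part of this commit:

```python
import torch

# Assumes the patched package is importable under the usual flash_attn
# name; adjust the import to match your build.
from flash_attn import flash_attn_func

# Hypothetical sizes: batch 2, seqlen 128, 8 heads, head dim 64.
# FlashAttention expects fp16/bf16 tensors on a CUDA device.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Preallocate the output buffer once; with this patch it is passed through
# to the forward kernel via the new keyword-only `out` argument instead of
# the hard-coded None, so the buffer can be reused across calls.
out = torch.empty_like(q)
attn = flash_attn_func(q, k, v, causal=True, out=out)
```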
@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -52,7 +52,7 @@ def _flash_attn_forward(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,6 +80,8 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -87,7 +89,7 @@ def _flash_attn_varlen_forward(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ class FlashAttnFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ class FlashAttnFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
        )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )
@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
...
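The same buffer reuse applies to the decode path, where `flash_attn_with_kvcache` previously always passed `None` for the output. A sketch of a reusable output buffer in a decode loop, under the same import caveat as above and with hypothetical shapes and cache lengths:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # same import caveat as above

batch, nheads, head_dim, cache_len = 2, 8, 64, 1024

# Single new query token per batch element, as in autoregressive decoding.
q = torch.randn(batch, 1, nheads, head_dim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, cache_len, nheads, head_dim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
# Number of valid entries currently in the cache, per batch element (hypothetical).
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")

# Allocate once and reuse across decode steps; the patch forwards this
# buffer to the kernel in place of the previously hard-coded None.
out = torch.empty_like(q)
flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, out=out
)
```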