Unverified commit d562aa63 authored by Woosuk Kwon, committed by GitHub

Sync with FA v2.6.0 to support soft capping (#13)

parent 12375706
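
Context for the change: soft capping bounds the attention scores smoothly inside (-softcap, softcap) before the softmax, as used by Gemma-2-style attention-logit capping. The snippet below is a minimal PyTorch reference of the semantics the new softcap argument is expected to follow; it is an illustrative sketch, not the fused CUDA kernel touched by this diff, and the helper name and shapes are assumptions.

import torch

def softcapped_attention_reference(q, k, v, softmax_scale=None, softcap=0.0):
    # Reference-only sketch of the assumed softcap semantics (not the fused kernel).
    # q, k, v: (batch_size, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bshd,bthd->bhst", q.float(), k.float()) * softmax_scale
    if softcap > 0.0:
        # Soft capping: squash the scores smoothly into (-softcap, softcap).
        scores = softcap * torch.tanh(scores / softcap)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", probs, v.float()).to(q.dtype)

With softcap <= 0 the tanh branch is skipped, which matches the "0.0 means deactivated" default used in the signatures below.
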
@@ -46,7 +46,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -60,6 +60,7 @@ def _flash_attn_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -78,6 +79,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     return_softmax,
     block_table,
@@ -103,6 +105,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -125,13 +128,19 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -147,6 +156,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -172,13 +182,19 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -199,6 +215,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -217,6 +234,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -233,6 +251,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -242,6 +261,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -265,12 +285,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None
 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -284,6 +305,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -304,6 +326,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -315,6 +338,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -342,12 +366,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None
 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -360,6 +385,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -375,6 +401,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -384,6 +411,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -408,13 +436,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None
 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -431,6 +460,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -450,6 +480,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -464,6 +495,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -492,13 +524,14 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
 class FlashAttnFunc(torch.autograd.Function):
@@ -512,6 +545,7 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -527,6 +561,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -536,6 +571,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -558,6 +594,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -565,7 +602,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -583,6 +620,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -603,6 +641,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
@@ -617,6 +656,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -643,6 +683,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -650,7 +691,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
 def flash_attn_qkvpacked_func(
@@ -659,6 +700,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # <=0.0 means deactivate
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -682,6 +724,7 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -704,6 +747,7 @@ def flash_attn_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -718,6 +762,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -757,6 +802,7 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -781,6 +827,7 @@ def flash_attn_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -796,6 +843,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -858,6 +906,7 @@ def flash_attn_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -873,6 +922,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -899,6 +949,7 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -908,7 +959,7 @@ def flash_attn_varlen_qkvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -923,6 +974,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -941,6 +993,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -986,6 +1039,7 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -996,7 +1050,7 @@ def flash_attn_varlen_kvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1014,6 +1068,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1033,6 +1088,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -1077,6 +1133,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -1087,7 +1144,7 @@ def flash_attn_varlen_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1106,6 +1163,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1128,9 +1186,11 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
     *,
     out=None,
 ):
@@ -1200,6 +1260,7 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
@@ -1211,9 +1272,13 @@ def flash_attn_with_kvcache(
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.
             Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1244,7 +1309,8 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
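
A usage sketch of the new arguments, assuming the public wrappers keep their existing names; the import path, tensor shapes, and the softcap value below are illustrative rather than part of this commit:

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache  # import path is an assumption

q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)

# softcap > 0 enables tanh soft capping of the attention scores;
# softcap=0.0 (the default) keeps the previous behaviour.
out = flash_attn_func(q, k, v, causal=True, softcap=30.0)

# Decode-time path: the new return_softmax_lse flag also returns the
# per-row logsumexp alongside the output.
k_cache = torch.randn(2, 4096, 8, 128, device="cuda", dtype=torch.float16)
v_cache = torch.randn(2, 4096, 8, 128, device="cuda", dtype=torch.float16)
q_dec = torch.randn(2, 1, 8, 128, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((2,), 1024, dtype=torch.int32, device="cuda")
out_dec, softmax_lse = flash_attn_with_kvcache(
    q_dec, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    softcap=30.0,
    return_softmax_lse=True,
)

Because return_softmax_lse defaults to False, existing callers of flash_attn_with_kvcache continue to receive a single output tensor.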