Unverified commit d562aa63 authored by Woosuk Kwon, committed by GitHub

Sync with FA v2.6.0 to support soft capping (#13)

parent 12375706
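For context, soft capping squashes the pre-softmax attention scores into (-softcap, +softcap) with a tanh instead of letting them grow unbounded. A minimal sketch of the transformation this commit threads through to the CUDA kernels, assuming the usual formulation (the exact kernel-side ordering relative to scaling and masking is an assumption here):

```python
import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    """Bound pre-softmax attention scores to (-softcap, +softcap) via tanh.

    softcap <= 0.0 leaves the scores untouched (the feature is disabled).
    """
    if softcap <= 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)
```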
......@@ -46,7 +46,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
......@@ -60,6 +60,7 @@ def _flash_attn_forward(
causal,
window_size[0],
window_size[1],
softcap,
return_softmax,
None,
)
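`_flash_attn_forward` now passes `softcap` straight to `flash_attn_cuda.fwd` between the window-size and return-softmax arguments. A slow pure-PyTorch reference of what the fused kernel computes with capping enabled can be handy for testing; this is only a sketch, assuming (batch, seqlen, nheads, headdim) inputs, seqlen_q == seqlen_k for the causal mask, and no dropout, sliding window, or ALiBi:

```python
import math
import torch

def attention_ref(q, k, v, softmax_scale=None, causal=False, softcap=0.0):
    # q, k, v: (batch, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bshd,bthd->bhst", q.float(), k.float()) * softmax_scale
    if softcap > 0.0:
        scores = softcap * torch.tanh(scores / softcap)  # soft capping
    if causal:
        s, t = scores.shape[-2:]
        mask = torch.triu(torch.ones(s, t, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", probs, v.float()).to(q.dtype)
```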
......@@ -78,6 +79,7 @@ def _flash_attn_varlen_forward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
return_softmax,
block_table,
......@@ -103,6 +105,7 @@ def _flash_attn_varlen_forward(
causal,
window_size[0],
window_size[1],
softcap,
return_softmax,
None,
)
......@@ -125,13 +128,19 @@ def _flash_attn_backward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
rng_state=None,
):
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
(
dq,
dk,
dv,
softmax_d,
) = flash_attn_cuda.bwd(
dout,
q,
k,
......@@ -147,6 +156,7 @@ def _flash_attn_backward(
causal,
window_size[0],
window_size[1],
softcap,
deterministic,
None,
rng_state,
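The backward entry points take the same `softcap` argument because the gradient through the capped scores needs the tanh derivative when the kernel recomputes scores. A sketch of the extra chain-rule factor involved (the mathematical relationship, not the kernel's exact code path):

```python
import torch

# For s_capped = softcap * tanh(s / softcap), the local derivative is
#   d s_capped / d s = 1 - tanh(s / softcap) ** 2,
# so gradients flowing back into the raw scores are scaled by this factor.
def softcap_backward_factor(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    if softcap <= 0.0:
        return torch.ones_like(scores)
    return 1.0 - torch.tanh(scores / softcap) ** 2
```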
......@@ -172,13 +182,19 @@ def _flash_attn_varlen_backward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
rng_state=None,
):
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
(
dq,
dk,
dv,
softmax_d,
) = flash_attn_cuda.varlen_bwd(
dout,
q,
k,
......@@ -199,6 +215,7 @@ def _flash_attn_varlen_backward(
causal,
window_size[0],
window_size[1],
softcap,
deterministic,
None,
rng_state,
......@@ -217,6 +234,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -233,6 +251,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
out=out,
......@@ -242,6 +261,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -265,12 +285,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None
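Each `backward` return tuple in this file grows by one trailing `None` because `torch.autograd.Function.backward` must return exactly one value per argument of `forward`, and the new `softcap` float receives no gradient. A minimal illustration of that convention (a hypothetical function, not part of this diff):

```python
import torch

class ScaledTanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cap):          # two forward args -> two backward returns
        ctx.cap = cap
        ctx.save_for_backward(x)
        return cap * torch.tanh(x / cap)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_x = grad_out * (1.0 - torch.tanh(x / ctx.cap) ** 2)
        return grad_x, None            # None for the non-tensor `cap` argument
```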
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
......@@ -284,6 +305,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -304,6 +326,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
......@@ -315,6 +338,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -342,12 +366,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
......@@ -360,6 +385,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -375,6 +401,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
out=out,
......@@ -384,6 +411,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -408,13 +436,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
......@@ -431,6 +460,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -450,6 +480,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
......@@ -464,6 +495,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -492,13 +524,14 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
......@@ -512,6 +545,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -527,6 +561,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
out=out,
......@@ -536,6 +571,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -558,6 +594,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -565,7 +602,7 @@ class FlashAttnFunc(torch.autograd.Function):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
......@@ -583,6 +620,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -603,6 +641,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
......@@ -617,6 +656,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -643,6 +683,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -650,7 +691,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
......@@ -659,6 +700,7 @@ def flash_attn_qkvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # <=0.0 means deactivate
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -682,6 +724,7 @@ def flash_attn_qkvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
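A usage sketch of the packed-QKV entry point with soft capping enabled; shapes, dtype, and the cap value are illustrative, and this assumes a CUDA build of the package importable as `flash_attn` (adjust the module name to your build):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func  # module name may differ per build

batch, seqlen, nheads, headdim = 2, 1024, 8, 128
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16)

# softcap=30.0 bounds the pre-softmax logits to (-30, 30); 0.0 disables capping.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True, softcap=30.0)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```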
......@@ -704,6 +747,7 @@ def flash_attn_qkvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -718,6 +762,7 @@ def flash_attn_kvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -757,6 +802,7 @@ def flash_attn_kvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
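The ALiBi bias described above can be materialized explicitly to check what the kernel adds to each score. A small sketch following the docstring's formula, with the query indices aligned to the bottom-right of the key axis:

```python
import torch

def alibi_bias(slopes, seqlen_q, seqlen_k):
    # slopes: (nheads,) -> bias: (nheads, seqlen_q, seqlen_k)
    i = torch.arange(seqlen_q).view(1, -1, 1)
    j = torch.arange(seqlen_k).view(1, 1, -1)
    rel = (i + seqlen_k - seqlen_q - j).abs()      # |i + seqlen_k - seqlen_q - j|
    return -slopes.view(-1, 1, 1).float() * rel
```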
......@@ -781,6 +827,7 @@ def flash_attn_kvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -796,6 +843,7 @@ def flash_attn_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -858,6 +906,7 @@ def flash_attn_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -873,6 +922,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -899,6 +949,7 @@ def flash_attn_varlen_qkvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
......@@ -908,7 +959,7 @@ def flash_attn_varlen_qkvpacked_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
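The docstring change reflects that, for the varlen entry points, the returned log-sum-exp is laid out per head over the packed token dimension rather than per (batch, seqlen). A slow reference for what each entry holds for a single packed sequence (varlen concatenates these per-sequence results along the token axis); this sketch assumes the same scaling as the kernel and ignores masking and dropout:

```python
import math
import torch

def softmax_lse_ref(q, k, softmax_scale=None):
    # q, k: (total_tokens, nheads, headdim) for a single sequence
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("shd,thd->hst", q.float(), k.float()) * softmax_scale
    return torch.logsumexp(scores, dim=-1)  # (nheads, total_tokens)
```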
......@@ -923,6 +974,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -941,6 +993,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -986,6 +1039,7 @@ def flash_attn_varlen_kvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
......@@ -996,7 +1050,7 @@ def flash_attn_varlen_kvpacked_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
......@@ -1014,6 +1068,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -1033,6 +1088,7 @@ def flash_attn_varlen_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -1077,6 +1133,7 @@ def flash_attn_varlen_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
......@@ -1087,7 +1144,7 @@ def flash_attn_varlen_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
......@@ -1106,6 +1163,7 @@ def flash_attn_varlen_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -1128,9 +1186,11 @@ def flash_attn_with_kvcache(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
return_softmax_lse=False,
*,
out=None,
):
......@@ -1200,6 +1260,7 @@ def flash_attn_with_kvcache(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
......@@ -1211,9 +1272,13 @@ def flash_attn_with_kvcache(
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
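A decoding-style sketch of the new `return_softmax_lse` flag together with `softcap`; the cache sizes, fill lengths, and module name are illustrative assumptions:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # module name may differ per build

batch, nheads, headdim, max_cache_len = 4, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_cache_len, nheads, headdim,
                      device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

out, lse = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    softcap=30.0,
    return_softmax_lse=True,   # new in this sync: also return the logsumexp
)
print(out.shape, lse.shape)  # (batch, 1, nheads, headdim), (batch, nheads, 1)
```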
......@@ -1244,7 +1309,8 @@ def flash_attn_with_kvcache(
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
return out
return (out, softmax_lse) if return_softmax_lse else out
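Finally, an end-to-end check one might run after this sync: compare `flash_attn_func` with capping enabled against the slow reference sketched earlier in these notes. Shapes, the cap value, and the expected tolerance are illustrative:

```python
import torch
from flash_attn import flash_attn_func  # module name may differ per build

torch.manual_seed(0)
q, k, v = [torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
           for _ in range(3)]

out_fa = flash_attn_func(q, k, v, causal=True, softcap=30.0)
out_ref = attention_ref(q, k, v, causal=True, softcap=30.0)  # reference sketched above

print((out_fa - out_ref).abs().max())  # expect a small fp16-level difference
```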