Unverified Commit e4155e96 authored by Baizhou Zhang, committed by GitHub

Add flash_attn_varlen_func to sgl-kernel (#5315)

parent 1b1b47a9
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out

def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
):
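    """Variable-length (packed) FlashAttention forward pass.

    q is packed as (total_q, num_heads, head_dim); k and v are
    (total_k, num_kv_heads, head_dim). cu_seqlens_q and cu_seqlens_k are
    int32 cumulative sequence lengths of shape (batch + 1,).
    Returns out, or (out, softmax_lse, *rest) when return_softmax_lse is True.
    """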
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
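    # Default scale is 1/sqrt(head_dim), counting the extra qv head_dim when qv is provided.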
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
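    # Call the fused FA3 forward op; KV-cache-only arguments (k_new/v_new,
    # page_table, rotary-embedding inputs) are passed as None in the varlen path.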
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k,
v,
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out
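
A minimal usage sketch of flash_attn_varlen_func, assuming an SM90+ GPU and that the function is exposed as sgl_kernel.flash_attn.flash_attn_varlen_func; shapes, dtypes, and the two packed sequence lengths (3 and 5) are illustrative.

import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 128
seqlens = [3, 5]             # two sequences packed along the token dimension
total_tokens = sum(seqlens)  # 8 tokens in the packed batch

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Cumulative sequence lengths, shape (batch + 1,): [0, 3, 8]
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
)
print(out.shape)  # torch.Size([8, 8, 128])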