Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e4155e96
Unverified
Commit
e4155e96
authored
Apr 11, 2025
by
Baizhou Zhang
Committed by
GitHub
Apr 11, 2025
Browse files
Add flash_attn_varlen_func to sgl-kernel (#5315)
parent
1b1b47a9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
0 deletions
+72
-0
sgl-kernel/python/sgl_kernel/flash_attn.py
sgl-kernel/python/sgl_kernel/flash_attn.py
+72
-0
No files found.
sgl-kernel/python/sgl_kernel/flash_attn.py
View file @
e4155e96
...
...
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
)
# return (out, softmax_lse) if return_softmax_lse else out
return
(
out
,
softmax_lse
,
*
rest
)
if
return_softmax_lse
else
out
def
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
seqused_q
=
None
,
seqused_k
=
None
,
softmax_scale
=
None
,
causal
=
False
,
qv
=
None
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
window_size
=
(
-
1
,
-
1
),
softcap
=
0.0
,
num_splits
=
1
,
pack_gqa
=
None
,
sm_margin
=
0
,
return_softmax_lse
=
False
,
):
if
not
is_fa3_supported
():
raise
NotImplementedError
(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
if
softmax_scale
is
None
:
softmax_scale
=
(
q
.
shape
[
-
1
]
+
(
qv
.
shape
[
-
1
]
if
qv
is
not
None
else
0
))
**
(
-
0.5
)
out
,
softmax_lse
,
*
rest
=
torch
.
ops
.
sgl_kernel
.
fwd
.
default
(
q
,
k
,
v
,
None
,
# k_new
None
,
# v_new
qv
,
# qv
None
,
# out
cu_seqlens_q
,
cu_seqlens_k
,
None
,
# cu_seqlens_k_new
seqused_q
,
seqused_k
,
max_seqlen_q
,
max_seqlen_k
,
None
,
# page_table,
None
,
# kv_batch_idx
None
,
# leftpad_k
None
,
# rotary cos
None
,
# rotary sin
None
,
# seqlens_rotary
q_descale
,
k_descale
,
v_descale
,
softmax_scale
,
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
is_rotary_interleaved
=
False
,
scheduler_metadata
=
None
,
num_splits
=
num_splits
,
pack_gqa
=
pack_gqa
,
sm_margin
=
sm_margin
,
)
return
(
out
,
softmax_lse
,
*
rest
)
if
return_softmax_lse
else
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment