Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9d5fa68b
"vscode:/vscode.git/clone" did not exist on "e6a151e7f2a3b7e72b52c30d5cea0045b272e44f"
Unverified
Commit
9d5fa68b
authored
Jun 08, 2025
by
Lianmin Zheng
Committed by
GitHub
Jun 08, 2025
Browse files
Use torch.compile to fuse flash attention decode metadata preparation (#6973)
parent
2c186425
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
18 deletions
+31
-18
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+31
-18
No files found.
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
9d5fa68b
...
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
...
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.utils
import
get_compiler_backend
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
...
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
)
)
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
else
:
else
:
metadata
=
self
.
decode_cuda_graph_metadata
[
bs
]
# Normal Decode
# Normal Decode
metadata
=
self
.
decode_cuda_graph_metadata
[
bs
]
max_len
=
seq_lens_cpu
.
max
().
item
()
max_len
=
seq_lens_cpu
.
max
().
item
()
max_seq_pages
=
(
max_len
+
self
.
page_size
-
1
)
//
self
.
page_size
metadata
.
max_seq_len_k
=
max_len
metadata
.
max_seq_len_k
=
max_len
metadata
.
cache_seqlens_int32
=
seq_lens
.
to
(
torch
.
int32
)
normal_decode_set_medadata
(
# Optimize cumulative sequence length calculation
metadata
,
metadata
.
cu_seqlens_k
[
1
:].
copy_
(
self
.
req_to_token
,
torch
.
cumsum
(
seq_lens
,
dim
=
0
,
dtype
=
torch
.
int32
)
req_pool_indices
,
self
.
decode_cuda_graph_metadata
[
"strided_indices"
],
max_seq_pages
,
seq_lens
,
self
.
page_size
,
)
)
max_seq_pages
=
(
metadata
.
max_seq_len_k
+
self
.
page_size
-
1
)
//
self
.
page_size
page_indices
=
self
.
req_to_token
[
req_pool_indices
[:,
None
],
self
.
decode_cuda_graph_metadata
[
"strided_indices"
][:
max_seq_pages
][
None
,
:
],
]
page_indices
//=
self
.
page_size
metadata
.
page_table
[:,
:
max_seq_pages
].
copy_
(
page_indices
)
metadata
.
page_table
[:,
max_seq_pages
:].
fill_
(
0
)
self
.
_update_local_attn_metadata_for_replay
(
metadata
,
bs
)
self
.
_update_local_attn_metadata_for_replay
(
metadata
,
bs
)
elif
forward_mode
.
is_target_verify
():
elif
forward_mode
.
is_target_verify
():
if
self
.
topk
<=
1
:
if
self
.
topk
<=
1
:
...
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
...
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
seq_lens_cpu
=
forward_batch
.
seq_lens_cpu
,
seq_lens_cpu
=
forward_batch
.
seq_lens_cpu
,
out_cache_loc
=
forward_batch
.
out_cache_loc
,
out_cache_loc
=
forward_batch
.
out_cache_loc
,
)
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
get_compiler_backend
())
def
normal_decode_set_medadata
(
metadata
,
req_to_token
,
req_pool_indices
,
strided_indices
,
max_seq_pages
,
seq_lens
,
page_size
,
):
metadata
.
cache_seqlens_int32
=
seq_lens
.
to
(
torch
.
int32
)
metadata
.
cu_seqlens_k
[
1
:].
copy_
(
torch
.
cumsum
(
seq_lens
,
dim
=
0
,
dtype
=
torch
.
int32
))
page_indices
=
req_to_token
[
req_pool_indices
[:,
None
],
strided_indices
[:
max_seq_pages
][
None
,
:],
]
metadata
.
page_table
[:,
:
max_seq_pages
].
copy_
(
page_indices
//
page_size
)
metadata
.
page_table
[:,
max_seq_pages
:].
fill_
(
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment