change / sglang / Commits / 7a913301
Unverified commit 7a913301, authored Aug 03, 2025 by Cheng Wan, committed by GitHub on Aug 03, 2025

Save cuda graph memory for fa3 (#8567)
parent 5ce5093b
Showing 1 changed file with 7 additions and 11 deletions

python/sglang/srt/layers/attention/flashattention_backend.py (+7 -11)
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 self.decode_cuda_graph_metadata[bs] = metadata
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.draft_decode_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens.max().item()
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-                req_pool_indices, :
+                :bs, :
             ]
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
                 : (bs + 1)
             ]
-            metadata.page_table = self.target_verify_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
             self.target_verify_metadata[bs] = metadata
         else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.target_verify_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                 : (bs + 1)
             ]
-            metadata.page_table = self.draft_extend_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
             self.draft_extend_metadata[bs] = metadata
@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
             ][: (encoder_bs + 1)]
             metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
-                req_pool_indices, :
+                :bs, :
             ]
         self.forward_metadata = metadata
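Every hunk makes the same substitution: during CUDA graph capture, the first `bs` rows of the preallocated page table are taken with a basic slice `[:bs, :]` instead of a gather by `req_pool_indices`. A basic slice is a view over the existing buffer, while advanced indexing materializes a fresh tensor, which is what this commit avoids. A minimal sketch of that distinction (not sglang code; NumPy stands in for torch, and the shapes, names, and the assumption that capture-time `req_pool_indices` is simply `0..bs-1` are illustrative):

```python
# Illustrative sketch: view-vs-copy behavior of the two indexing styles
# used in the diff above. NumPy and torch share this distinction.
import numpy as np

max_bs, max_pages = 8, 4
# Stand-in for the preallocated "page_table" buffer in the cuda-graph metadata.
page_table = np.arange(max_bs * max_pages).reshape(max_bs, max_pages)

bs = 3
# Assumption for illustration: at capture time the request pool indices
# are just the first bs slots.
req_pool_indices = np.arange(bs)

gathered = page_table[req_pool_indices, :]  # advanced indexing: allocates a copy
sliced = page_table[:bs, :]                 # basic slice: a view, no new buffer

assert not np.shares_memory(gathered, page_table)  # copy
assert np.shares_memory(sliced, page_table)        # view
assert (gathered == sliced).all()                  # same values in this case
```

Because a CUDA graph replays the captured kernels against fixed buffers, keeping `metadata.page_table` as a view of the preallocated table avoids holding a second, gathered copy alive for every captured batch size.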