Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cded039b
Unverified
Commit
cded039b
authored
Aug 21, 2025
by
Stefan He
Committed by
GitHub
Aug 21, 2025
Browse files
[FA3] Init Spec Page Table only when Spec is enabled to save ~40MB (#9455)
parent
275f9df3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
10 deletions
+13
-10
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+13
-10
No files found.
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
cded039b
...
...
@@ -1163,6 +1163,8 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
max_num_pages
=
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
# This is being used by normal decode and draft decode when topk == 1
self
.
decode_cuda_graph_metadata
=
{
"cache_seqlens"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
...
...
@@ -1174,13 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
"page_table_draft_decode"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
...
...
@@ -1188,7 +1184,6 @@ class FlashAttentionBackend(AttentionBackend):
0
,
self
.
max_context_len
,
self
.
page_size
,
device
=
self
.
device
),
}
# Only allocate local attention buffers if local attention is enabled
# This prevents OOM errors when local attention is not being used
if
self
.
attention_chunk_size
is
not
None
:
...
...
@@ -1274,6 +1269,14 @@ class FlashAttentionBackend(AttentionBackend):
self
.
speculative_num_draft_tokens
is
not
None
and
self
.
speculative_num_draft_tokens
>
0
):
# "page_table_draft_decode" will be set only when spec decoding enabled to save memory
self
.
decode_cuda_graph_metadata
[
"page_table_draft_decode"
]
=
torch
.
zeros
(
max_bs
,
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
target_verify_metadata
=
{
"cache_seqlens"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
...
...
@@ -1290,7 +1293,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
...
...
@@ -1313,7 +1316,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"page_table"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment