change / sglang · Commits · 205d5cb4

Commit 205d5cb4 (unverified), authored May 17, 2025 by Chang Su, committed by GitHub on May 17, 2025
perf: Optimize local attention memory allocation in FlashAttentionBackend (#6356)
Parent: 3d7f7a43
Showing 1 changed file with 57 additions and 13 deletions:
python/sglang/srt/layers/attention/flashattention_backend.py (+57, −13)
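What the change does: during CUDA graph capture for decode, the backend used to hand the full pre-allocated local-attention buffers to LocalAttentionMetadata, so every captured graph referenced maximum-size tensors. This commit instead computes the exact sizes the captured batch needs and takes zero-copy slices (views) of those buffers. A minimal sketch of the PyTorch behavior the optimization relies on (buffer name and sizes below are illustrative, not from the diff):

```python
import torch

# Pre-allocate once at the maximum size the CUDA graph could ever need.
MAX_VIRTUAL_BATCHES = 1024  # assumed capacity, for illustration only
buf = torch.zeros(MAX_VIRTUAL_BATCHES, dtype=torch.int32)

# At capture time, slice a view sized to what this batch actually uses.
q_len = 17  # assumed exact size computed for the current batch
view = buf[:q_len]

# Slicing creates a view, not a copy: no new memory is allocated, and the
# view aliases the parent buffer's storage.
assert view.data_ptr() == buf.data_ptr()
assert view.shape == (q_len,)
```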
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1434,19 +1434,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
 
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
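Reading the hunk: the removed block constructed LocalAttentionMetadata directly from the full pre-allocated buffers and pinned local_max_seq_len to 1; the single added line delegates to the new _update_local_attn_metadata_for_capture helper (added below), which slices the buffers to exact sizes and derives the true maximum sequence length from metadata.cache_seqlens_int32. The sizing math runs on the host (.cpu().numpy()) once per capture rather than on the hot path; replay has its own _update_local_attn_metadata_for_replay.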
@@ -1807,6 +1795,62 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during CUDA graph capture phase.
+
+        This method calculates the exact buffer sizes needed for local attention
+        metadata during the CUDA graph capture phase, optimizing memory usage by
+        creating views of pre-allocated buffers with exactly the sizes needed.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get exact dimensions from the calculation
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes
+        # This is the key optimization - we only use the memory we actually need
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata[
+            "local_seqused_k"
+        ][:k_len]
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
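One detail worth noting in the new helper: q_len, k_len, b0, and b1 come from the shapes that make_local_attention_virtual_batches actually produced, and the b0/b1 fallbacks guard against a degenerate result, falling back to bs rows and one column so the captured graph never gets a zero-size view. A standalone sketch of that guard (the function name and values below are illustrative; only the conditional pattern mirrors the diff):

```python
import numpy as np

def guarded_dims(block_table_local_np: np.ndarray, bs: int) -> tuple[int, int]:
    """Mirror the diff's fallback: never let a view dimension collapse to 0."""
    b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
    b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
    return b0, b1

# Normal case: dimensions pass through unchanged.
assert guarded_dims(np.zeros((16, 8), dtype=np.int32), bs=4) == (16, 8)
# Degenerate case: fall back to the batch size and a single column.
assert guarded_dims(np.zeros((0, 0), dtype=np.int32), bs=4) == (4, 1)
```

Only local_max_query_len remains hard-coded at 1, which matches decode, where each request contributes a single query token; local_max_seq_len now reflects the real maximum cache length instead of the previous placeholder of 1.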