sglang · Commits · c776234b

Commit c776234b (unverified), authored Apr 17, 2025 by Chang Su, committed by GitHub on Apr 17, 2025
Enable local attention during decode (#5479)
Parent: 3bface15

Showing 1 changed file with 113 additions and 68 deletions (+113 / -68):
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
        seqlens_k_local: Key sequence lengths for local attention
        block_table_local: Block table for local attention
    """
    # Adjust attention_chunk_size based on the actual sequence length
    # to avoid index out of bounds errors
    max_seq_len = seq_lens_np.max()
    effective_chunk_size = min(attn_chunk_size, max_seq_len)
    # Make sure effective_chunk_size is divisible by page_size
    effective_chunk_size = (effective_chunk_size // page_size) * page_size
    if effective_chunk_size < page_size:
        effective_chunk_size = page_size
    attn_chunk_size = effective_chunk_size

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]
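
A minimal standalone sketch of the chunk-size clamping introduced above, assuming NumPy inputs shaped like seq_lens_np; the helper name and the example values are illustrative, not part of the commit:

    import numpy as np

    def clamp_chunk_size(attn_chunk_size: int, seq_lens_np: np.ndarray, page_size: int) -> int:
        # Cap the chunk at the longest sequence, round down to a multiple of
        # page_size, and never drop below a single page.
        max_seq_len = int(seq_lens_np.max())
        effective = min(attn_chunk_size, max_seq_len)
        effective = (effective // page_size) * page_size
        return max(effective, page_size)

    # With page_size=16 and a longest sequence of 50 tokens, a configured chunk of
    # 128 collapses to 48 (three pages), so chunk indexing stays in bounds.
    print(clamp_chunk_size(128, np.array([12, 50, 33]), 16))  # -> 48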

@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            self._init_local_attn_metadata(metadata, device)
        else:
            # Normal Decode
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)

@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]

                self._init_local_attn_metadata(metadata, device)
        elif forward_batch.forward_mode.is_target_verify():
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + self.speculative_num_draft_tokens
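
The target-verify branch simply extends each request's KV length by the number of draft tokens being checked; a toy calculation with made-up lengths:

    import torch

    speculative_num_draft_tokens = 4
    seq_lens = torch.tensor([17, 260, 9])
    cache_seqlens_int32 = (seq_lens + speculative_num_draft_tokens).to(torch.int32)
    print(cache_seqlens_int32.tolist())  # [21, 264, 13]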

@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.cu_seqlens_q = metadata.cu_seqlens_k

            # Setup local attention if enabled
            if (
                self.attention_chunk_size is not None
                and forward_batch.forward_mode == ForwardMode.EXTEND
            ):
                # Convert tensors to numpy for local attention processing
                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()

                # Adjust attention_chunk_size based on the actual sequence length
                # to avoid index out of bounds errors
                max_seq_len = seq_lens_np.max()
                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
                # Make sure effective_chunk_size is divisible by page_size
                effective_chunk_size = (
                    effective_chunk_size // self.page_size
                ) * self.page_size
                if effective_chunk_size < self.page_size:
                    effective_chunk_size = self.page_size

                # Create local attention metadata
                (
                    seqlens_q_local_np,
                    cu_seqlens_q_local_np,
                    seqlens_k_local_np,
                    block_table_local,
                ) = make_local_attention_virtual_batches(
                    effective_chunk_size,
                    cu_seqlens_q_np,
                    seq_lens_np,
                    metadata.page_table,
                    self.page_size,
                )

                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
                        device
                    ),
                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
                    local_block_table=block_table_local,
                    local_max_query_len=seqlens_q_local_np.max(),
                    local_max_seq_len=seqlens_k_local_np.max(),
                )
                metadata.local_attn_metadata = local_metadata

            if forward_batch.forward_mode == ForwardMode.EXTEND:
                self._init_local_attn_metadata(metadata, device)

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
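
This hunk folds the extend-only, inline local-attention setup into the shared _init_local_attn_metadata helper, which is what lets the decode paths above reuse the same logic. The practical effect during decode is that a new token only needs the keys of the chunk containing its position; a back-of-the-envelope sketch with made-up numbers:

    # A decode step at position seq_len - 1 with chunked ("local") attention
    # attends only within the chunk that contains that position.
    attention_chunk_size = 128
    seq_len = 300
    chunk_start = ((seq_len - 1) // attention_chunk_size) * attention_chunk_size
    keys_needed = seq_len - chunk_start
    print(chunk_start, keys_needed)  # 256 44 -> 44 keys instead of 300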

@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
        use_local_attention = (
            self.attention_chunk_size is not None and local_attn_metadata is not None
        )

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
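
A small standalone illustration of the guard above: the local-attention path is taken only when the backend is configured with a chunk size and the metadata was actually populated (the stand-in class is hypothetical, not the sglang type):

    class _Meta:  # stand-in for the backend's metadata object
        pass

    attention_chunk_size = 128
    metadata = _Meta()
    local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
    use_local_attention = (
        attention_chunk_size is not None and local_attn_metadata is not None
    )
    print(use_local_attention)  # False until _init_local_attn_metadata has filled it in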

@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)
                # Always use non-chunked logic for cross-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=metadata.encoder_page_table,
                    cache_seqlens=metadata.encoder_lens_int32,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
            elif use_local_attention:
                # Use chunked (local) attention batching for self-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=local_attn_metadata.local_block_table,
                    cache_seqlens=local_attn_metadata.local_seqused_k,
                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
                    cu_seqlens_k_new=metadata.cu_seqlens_k,
                    max_seqlen_q=local_attn_metadata.local_max_query_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
            else:
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k

                o = flash_attn_with_kvcache(
                    q=q_reshaped,
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=causal,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )

                # Default: single-token self-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=metadata.page_table,
                    cache_seqlens=metadata.cache_seqlens_int32,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=metadata.cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=True,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
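
One way to read the branching above: the flash_attn_with_kvcache call keeps the same shape, and only the paging and length tensors change per branch. A schematic dispatcher for the two self-attention cases, assuming objects with the attribute names shown in the diff (the dataclass and function are illustrative, not sglang APIs):

    from dataclasses import dataclass
    import torch

    @dataclass
    class KernelArgs:
        page_table: torch.Tensor
        cache_seqlens: torch.Tensor
        cu_seqlens_q: torch.Tensor
        max_seqlen_q: int
        causal: bool

    def pick_decode_args(metadata, local_attn_metadata, use_local_attention: bool) -> KernelArgs:
        if use_local_attention:
            # Chunked self-attention: the virtual-batch tensors replace the
            # per-request ones, so each "request" is one chunk of the KV cache.
            return KernelArgs(
                page_table=local_attn_metadata.local_block_table,
                cache_seqlens=local_attn_metadata.local_seqused_k,
                cu_seqlens_q=local_attn_metadata.local_query_start_loc,
                max_seqlen_q=int(local_attn_metadata.local_max_query_len),
                causal=True,
            )
        # Normal decode: one query token per request over the full KV length.
        return KernelArgs(
            page_table=metadata.page_table,
            cache_seqlens=metadata.cache_seqlens_int32,
            cu_seqlens_q=metadata.cu_seqlens_q,
            max_seqlen_q=1,
            causal=True,
        )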

@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        device = seq_lens.device

        if forward_mode.is_decode_or_idle():
            metadata = self.decode_cuda_graph_metadata[bs]

@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                ]
                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

                self._init_local_attn_metadata(metadata, device)
            else:
                # Normal Decode
                max_len = seq_lens_cpu.max().item()
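
The in-place copy_ above (and the fill_ in the next hunk) matters because CUDA-graph replay reuses the buffers captured when the graph was built, so updates must be written into the existing metadata.page_table rather than rebinding it to a new tensor. A minimal illustration with made-up shapes and indices:

    import torch

    max_bs, max_pages = 4, 8
    page_table = torch.zeros(max_bs, max_pages, dtype=torch.int32)  # captured buffer

    page_indices = torch.tensor(
        [[3, 7, 1], [4, 4, 9], [0, 2, 5], [6, 6, 6]], dtype=torch.int32
    )
    max_seq_pages = page_indices.shape[1]

    page_table[:, :max_seq_pages].copy_(page_indices)  # update the active prefix in place
    page_table[:, max_seq_pages:].fill_(0)             # clear stale tail entries
    print(page_table)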

@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)
                self._init_local_attn_metadata(metadata, device)
        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
            metadata.cache_seqlens_int32.copy_(

@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
        """Get the fill value for sequence length in CUDA graph."""
        return 0

    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
        if self.attention_chunk_size is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q = metadata.cu_seqlens_q
        cache_seqlens_int32 = metadata.cache_seqlens_int32
        page_table = metadata.page_table
        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
        seq_lens_np = cache_seqlens_int32.cpu().numpy()
        (
            seqlens_q_local_np,
            cu_seqlens_q_local_np,
            seqlens_k_local_np,
            block_table_local,
        ) = make_local_attention_virtual_batches(
            self.attention_chunk_size,
            cu_seqlens_q_np,
            seq_lens_np,
            page_table,
            self.page_size,
        )
        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
            local_block_table=block_table_local.to(device),
            local_max_query_len=int(seqlens_q_local_np.max()),
            local_max_seq_len=int(seqlens_k_local_np.max()),
        )
        metadata.local_attn_metadata = local_metadata


class FlashAttentionMultiStepBackend:
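
The new helper hands the NumPy outputs of make_local_attention_virtual_batches back to the device and casts the maxima to plain Python ints. A standalone illustration of that round trip, with made-up arrays standing in for the helper's outputs:

    import numpy as np
    import torch

    cu_seqlens_q_local_np = np.array([0, 1, 2, 4], dtype=np.int32)
    seqlens_k_local_np = np.array([44, 128, 128, 44], dtype=np.int32)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    local_query_start_loc = torch.from_numpy(cu_seqlens_q_local_np).to(device)
    local_seqused_k = torch.from_numpy(seqlens_k_local_np).to(device)
    local_max_seq_len = int(seqlens_k_local_np.max())  # plain int rather than np.int32

    print(local_query_start_loc.device, local_seqused_k.dtype, local_max_seq_len)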