change / sglang · Commits · c776234b

Commit c776234b (unverified), authored Apr 17, 2025 by Chang Su, committed by GitHub on Apr 17, 2025

Enable local attention during decode (#5479)

parent 3bface15

Showing 1 changed file with 113 additions and 68 deletions:

python/sglang/srt/layers/attention/flashattention_backend.py  (+113, -68)
@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
        seqlens_k_local: Key sequence lengths for local attention
        block_table_local: Block table for local attention
    """
    # Adjust attention_chunk_size based on the actual sequence length
    # to avoid index out of bounds errors
    max_seq_len = seq_lens_np.max()
    effective_chunk_size = min(attn_chunk_size, max_seq_len)
    # Make sure effective_chunk_size is divisible by page_size
    effective_chunk_size = (effective_chunk_size // page_size) * page_size
    if effective_chunk_size < page_size:
        effective_chunk_size = page_size
    attn_chunk_size = effective_chunk_size

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]
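
For reference, a minimal standalone sketch of the clamping logic added above, with hypothetical inputs (the batch values below are made up, not taken from the patch): the chunk is shrunk to the longest sequence actually present and rounded down to a multiple of page_size, with page_size itself as the floor.

    import numpy as np

    def clamp_chunk_size(attn_chunk_size: int, seq_lens_np: np.ndarray, page_size: int) -> int:
        # Shrink the chunk to the longest sequence in the batch ...
        max_seq_len = int(seq_lens_np.max())
        effective = min(attn_chunk_size, max_seq_len)
        # ... then round down to a multiple of page_size, never below one page.
        effective = (effective // page_size) * page_size
        return max(effective, page_size)

    # Hypothetical batch: three requests with 300, 17, and 1025 cached tokens.
    print(clamp_chunk_size(8192, np.array([300, 17, 1025]), page_size=64))  # -> 1024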
@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
                self._init_local_attn_metadata(metadata, device)
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
            self._init_local_attn_metadata(metadata, device)
        elif forward_batch.forward_mode.is_target_verify():
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + self.speculative_num_draft_tokens
@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.cu_seqlens_q = metadata.cu_seqlens_k

            # Setup local attention if enabled
            if (
                self.attention_chunk_size is not None
                and forward_batch.forward_mode == ForwardMode.EXTEND
            ):
                # Convert tensors to numpy for local attention processing
                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()

                # Adjust attention_chunk_size based on the actual sequence length
                # to avoid index out of bounds errors
                max_seq_len = seq_lens_np.max()
                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
                # Make sure effective_chunk_size is divisible by page_size
                effective_chunk_size = (
                    effective_chunk_size // self.page_size
                ) * self.page_size
                if effective_chunk_size < self.page_size:
                    effective_chunk_size = self.page_size

                # Create local attention metadata
                (
                    seqlens_q_local_np,
                    cu_seqlens_q_local_np,
                    seqlens_k_local_np,
                    block_table_local,
                ) = make_local_attention_virtual_batches(
                    effective_chunk_size,
                    cu_seqlens_q_np,
                    seq_lens_np,
                    metadata.page_table,
                    self.page_size,
                )
                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
                    local_block_table=block_table_local,
                    local_max_query_len=seqlens_q_local_np.max(),
                    local_max_seq_len=seqlens_k_local_np.max(),
                )
                metadata.local_attn_metadata = local_metadata

            if forward_batch.forward_mode == ForwardMode.EXTEND:
                self._init_local_attn_metadata(metadata, device)

            # Encoder metadata for cross attention
            if forward_batch.encoder_lens is not None:
@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
        use_local_attention = (
            self.attention_chunk_size is not None and local_attn_metadata is not None
        )

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
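
A minimal sketch of the guard pattern introduced above, using a made-up metadata stand-in (the _Meta class is hypothetical, for illustration only): the decode path takes the chunked branch only when a chunk size is configured and the local metadata was actually built.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _Meta:  # hypothetical stand-in for FlashAttentionMetadata
        local_attn_metadata: Optional[object] = None

    def should_use_local_attention(attention_chunk_size: Optional[int], metadata) -> bool:
        # getattr(..., None) tolerates metadata objects that never gained the field.
        local = getattr(metadata, "local_attn_metadata", None)
        return attention_chunk_size is not None and local is not None

    print(should_use_local_attention(256, _Meta()))           # False: metadata not initialized
    print(should_use_local_attention(None, _Meta(object())))  # False: chunked attention disabled
    print(should_use_local_attention(256, _Meta(object())))   # True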
@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

            if layer.is_cross_attention:
                page_table = metadata.encoder_page_table
                cache_seqlens = metadata.encoder_lens_int32
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)
                # Always use non-chunked logic for cross-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=metadata.encoder_page_table,
                    cache_seqlens=metadata.encoder_lens_int32,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=False,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
            elif use_local_attention:
                # Use chunked (local) attention batching for self-attention
                o = flash_attn_with_kvcache(
                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=local_attn_metadata.local_block_table,
                    cache_seqlens=local_attn_metadata.local_seqused_k,
                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
                    cu_seqlens_k_new=metadata.cu_seqlens_k,
                    max_seqlen_q=local_attn_metadata.local_max_query_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                    window_size=(-1, -1),
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
            else:
                page_table = metadata.page_table
                cache_seqlens = metadata.cache_seqlens_int32
                cu_seqlens_k = metadata.cu_seqlens_k

                o = flash_attn_with_kvcache(
                    q=q_reshaped,
                    k_cache=key_cache,
                    v_cache=value_cache,
                    page_table=page_table,
                    cache_seqlens=cache_seqlens,
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k,
                    max_seqlen_q=1,
                    softmax_scale=layer.scaling,
                    causal=causal,
                    window_size=window_size,
                    softcap=layer.logit_cap,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
            # Default: single-token self-attention
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        device = seq_lens.device

        if forward_mode.is_decode_or_idle():
            metadata = self.decode_cuda_graph_metadata[bs]
@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                ]
                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

                self._init_local_attn_metadata(metadata, device)
            else:
                # Normal Decode
                max_len = seq_lens_cpu.max().item()
@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)

                self._init_local_attn_metadata(metadata, device)
        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
            metadata.cache_seqlens_int32.copy_(
@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
        """Get the fill value for sequence length in CUDA graph."""
        return 0

    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
        if self.attention_chunk_size is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q = metadata.cu_seqlens_q
        cache_seqlens_int32 = metadata.cache_seqlens_int32
        page_table = metadata.page_table
        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
            metadata.local_attn_metadata = None
            return

        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
        seq_lens_np = cache_seqlens_int32.cpu().numpy()
        (
            seqlens_q_local_np,
            cu_seqlens_q_local_np,
            seqlens_k_local_np,
            block_table_local,
        ) = make_local_attention_virtual_batches(
            self.attention_chunk_size,
            cu_seqlens_q_np,
            seq_lens_np,
            page_table,
            self.page_size,
        )
        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
            local_block_table=block_table_local.to(device),
            local_max_query_len=int(seqlens_q_local_np.max()),
            local_max_seq_len=int(seqlens_k_local_np.max()),
        )
        metadata.local_attn_metadata = local_metadata


class FlashAttentionMultiStepBackend:
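
As a rough illustration of what the helper above is preparing for decode, the sketch below assumes chunked local attention in which each token attends only to keys inside its own attention_chunk_size-aligned window; it is a simplification and does not reproduce the exact arrays returned by make_local_attention_virtual_batches.

    import numpy as np

    def local_kv_window(seq_lens_np: np.ndarray, attn_chunk_size: int) -> np.ndarray:
        # For each decode request, count the cached keys that fall inside the
        # chunk containing the newest token (assumed local-attention behavior).
        return (seq_lens_np - 1) % attn_chunk_size + 1

    # Hypothetical batch: requests with 300, 17, and 1025 cached tokens, chunk size 256.
    seq_lens = np.array([300, 17, 1025])
    print(local_kv_window(seq_lens, 256))  # -> [44 17  1]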