sglang · Commit 1940cdec

[Bugfix] Fix Llama4 gibberish output with long context and CUDA graph (#6162)

Authored May 09, 2025 by Chang Su; committed via GitHub on May 09, 2025.
Parent commit: 63484f9f

1 changed file with 125 additions and 8 deletions:

python/sglang/srt/layers/attention/flashattention_backend.py (+125, −8)
```diff
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-        use_local_attention = (
-            self.attention_chunk_size is not None and local_attn_metadata is not None
-        )
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
```
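The behavioral change in this hunk is that the local (chunked) attention path is now additionally gated on the layer actually using iRoPE, so global-attention layers no longer fall into the local path. A minimal sketch of the resulting dispatch, with a hypothetical `AttnLayer` class standing in for sglang's real layer objects (`use_irope` and `attention_chunk_size` come from the diff; everything else is illustrative):

```python
class AttnLayer:
    """Hypothetical stand-in for an attention layer object."""

    def __init__(self, use_irope: bool):
        self.use_irope = use_irope


def pick_path(layer, attention_chunk_size, local_attn_metadata):
    # Same condition as the fixed code: local attention requires a chunk
    # size, precomputed local metadata, AND a layer flagged as iRoPE.
    use_local_attn = (
        attention_chunk_size is not None
        and local_attn_metadata is not None
        and (hasattr(layer, "use_irope") and layer.use_irope)
    )
    return "local" if use_local_attn else "global"


assert pick_path(AttnLayer(True), 8192, object()) == "local"
# Before the fix, this layer would also have taken the local path:
assert pick_path(AttnLayer(False), 8192, object()) == "global"
```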
```diff
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif use_local_attention:
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
```
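As background for the `q.contiguous().view(...)` call in the context above, a small shape sketch with assumed example dimensions (the sizes are illustrative, not from the commit): packed varlen queries arrive as a 2-D tensor and the kvcache entry point expects them split per head.

```python
import torch

total_tokens, tp_q_head_num, head_dim = 6, 8, 128  # assumed example sizes
q = torch.randn(total_tokens, tp_q_head_num * head_dim)
# Reshape to [total_tokens, num_heads, head_dim] for the attention kernel.
q3 = q.contiguous().view(-1, tp_q_head_num, head_dim)
print(q3.shape)  # torch.Size([6, 8, 128])
```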
```diff
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
```
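For readers unfamiliar with the `cu_seqlens_*` convention used throughout this file: these are exclusive cumulative sums of per-sequence lengths, which is how flash-attn's varlen interface locates each sequence inside a packed batch. A standalone illustration with assumed example values:

```python
import torch

# Three sequences packed into one varlen batch.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
# Prepend 0 to the running sum: sequence i occupies rows
# cu_seqlens[i] : cu_seqlens[i + 1] of the packed tensor.
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```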
```diff
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-        # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(
                 max_bs, dtype=torch.int32, device=self.device
             ),
```
```diff
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
```
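The buffer sizes above are worst-case bounds computed with ceil division: every sequence contributes one virtual batch per attention chunk, and each chunk needs at most one block's worth of KV-cache pages. Plugging in assumed example values (these numbers are illustrative, not sglang defaults):

```python
# Assumed example configuration
max_bs = 4              # max CUDA graph batch size
max_seq_len = 10240     # self.max_context_len
attn_chunk_size = 8192  # Llama4-style local attention chunk
page_size = 64          # KV cache page size

# Ceil division: number of chunks (= virtual batches) per sequence.
chunks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size  # 2
max_virtual_batches = max_bs * chunks_per_seq                            # 8
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size     # 128

# local_block_table is therefore [8, 2 * 128] int32 in this example.
print(max_virtual_batches, chunks_per_seq * max_pages_per_block)  # 8 256
```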
```diff
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 self.decode_cuda_graph_metadata[bs] = metadata
+                if self.attention_chunk_size is not None:
+                    metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                        local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                            "local_query_start_loc"
+                        ],
+                        local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                            "local_seqused_k"
+                        ],
+                        local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                            "local_block_table"
+                        ],
+                        local_max_query_len=1,
+                        local_max_seq_len=1,
+                    )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[
```
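Note that the capture-time metadata wraps the preallocated buffers themselves, not copies: CUDA graph replay re-executes kernels against the device pointers recorded at capture, so pre-replay updates must land in the same storage. A minimal sketch of that aliasing invariant (plain tensors, no sglang API):

```python
import torch

buf = torch.zeros(8, dtype=torch.int32)  # preallocated at init (capture time)
view = buf                               # metadata holds a reference, not a copy
# Pre-replay, in-place update of the leading slice:
buf[:4].copy_(torch.tensor([1, 2, 3, 4], dtype=torch.int32))
# Same storage, so kernels captured against `view` see the refreshed values.
assert view.data_ptr() == buf.data_ptr()
```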
```diff
@@ -1572,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-            # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
             metadata = self.decode_cuda_graph_metadata[bs]
             # Normal Decode
```
```diff
@@ -1599,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table[:, :max_seq_pages].copy_(page_indices)
             metadata.page_table[:, max_seq_pages:].fill_(0)
-            self._init_local_attn_metadata(metadata, device)
+            self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
```
```diff
@@ -1755,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
```
```diff
@@ -1764,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+
+
 class FlashAttentionMultiStepBackend:
```
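As a rough illustration of what `make_local_attention_virtual_batches` does with the decode inputs above, here is a toy NumPy sketch of the chunk split only (assumed semantics; the real helper also produces query offsets and a per-chunk block table): one decode request with 20000 cached tokens and an 8192-token chunk becomes three virtual batches, and the single query token attends within the last chunk.

```python
import numpy as np


def split_into_chunks(seqlen_k: int, chunk: int) -> np.ndarray:
    """Per-virtual-batch K lengths for one sequence (toy version)."""
    n = (seqlen_k + chunk - 1) // chunk            # ceil division: number of chunks
    lens = np.full(n, chunk, dtype=np.int32)       # full chunks...
    lens[-1] = seqlen_k - chunk * (n - 1)          # ...with the remainder in the last
    return lens


print(split_into_chunks(20000, 8192))  # [8192 8192 3616]
```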