Unverified commit 2dae104d, authored Jun 10, 2025 by Lianmin Zheng; committed by GitHub on Jun 10, 2025.
Minor cleanup of fa3 backend (#6999)
Parent: cef6655b
Showing 2 changed files with 63 additions and 64 deletions.
python/sglang/srt/layers/attention/flashattention_backend.py (+48, -48)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+15, -16)
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
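Most of the changes in this file follow the same pattern as the hunk above: the explicit `.to(torch.int32)` before a `copy_` into an int32 buffer is dropped, since `torch.Tensor.copy_` already casts the source to the destination's dtype. A minimal standalone sketch of that behavior (toy values, not sglang buffers):

    import torch

    # copy_ casts the source to the destination's dtype, so pre-casting the
    # right-hand side to int32 before copying into an int32 buffer is redundant.
    cache_seqlens_int32 = torch.zeros(4, dtype=torch.int32)
    seq_lens = torch.tensor([3, 5, 7, 9], dtype=torch.int64)

    cache_seqlens_int32.copy_(seq_lens + 2)  # implicit int64 -> int32 cast
    assert cache_seqlens_int32.dtype == torch.int32
    assert cache_seqlens_int32.tolist() == [5, 7, 9, 11]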
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
             if spec_info is not None:
                 # Draft Decode
                 if self.topk <= 1:
-                    metadata = self.decode_cuda_graph_metadata[bs]
                     # When topk = 1, we use the normal decode metadata
+                    metadata = self.decode_cuda_graph_metadata[bs]
-                    metadata.cache_seqlens_int32.copy_(
-                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                    )
-                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                        self.speculative_step_id + 1
-                    )
-                    metadata.cu_seqlens_k[1:].copy_(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        )
-                    )
+                    max_len = seq_lens_cpu.max().item()
+                    metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                     max_seq_pages = (
                         metadata.max_seq_len_k + self.page_size - 1
                     ) // self.page_size
-                    page_indices = self.req_to_token[
-                        req_pool_indices[:, None],
-                        self.decode_cuda_graph_metadata["strided_indices"][
-                            :max_seq_pages
-                        ],
-                    ]
-                    page_indices //= self.page_size
-                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                    normal_decode_set_medadata(
+                        metadata.cache_seqlens_int32,
+                        metadata.cu_seqlens_k,
+                        metadata.page_table,
+                        self.req_to_token,
+                        req_pool_indices,
+                        self.decode_cuda_graph_metadata["strided_indices"],
+                        max_seq_pages,
+                        seq_lens,
+                        self.speculative_step_id + 1,
+                        self.page_size,
+                    )
                 else:
                     # When top k > 1, we need two specific draft decode metadata, and then merge states
                     # 1. The first half of metadata for prefix tokens
                     metadata = self.draft_decode_metadata_topk_normal[bs]
-                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    metadata.cache_seqlens_int32.copy_(seq_lens)
                     # metadata.max_seq_len_q = self.topk, already set in capture
                     metadata.max_seq_len_k = seq_lens_cpu.max().item()
                     # metadata.cu_seqlens_q already set in capture
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                         self.speculative_num_steps, -1
                     ).T.contiguous()
                     metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                        cache_loc[:, :decode_length]
                     )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = max_len
                 normal_decode_set_medadata(
-                    metadata,
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
                     self.req_to_token,
                     req_pool_indices,
                     self.decode_cuda_graph_metadata["strided_indices"],
                     max_seq_pages,
                     seq_lens,
+                    0,
                     self.page_size,
                 )
@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )
                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
                 )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
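The `max_seq_len_q` line now takes the maximum over `spec_info.accept_length_cpu` instead of calling `.max().item()` on the device tensor; on a CUDA tensor, `.item()` forces a device-to-host synchronization, so reading an existing host-side copy is cheaper during graph replay. A small standalone sketch of that pattern (values and names are illustrative, not sglang's):

    import torch

    # Illustrative only: a small length tensor plus a host-side copy of it.
    accept_length = torch.tensor([2, 4, 3])
    accept_length_cpu = accept_length.tolist()  # one transfer, done ahead of time

    max_len_via_item = accept_length.max().item()  # .item() syncs when the tensor lives on GPU
    max_len_via_cpu = max(accept_length_cpu)       # pure host-side computation

    assert max_len_via_item == max_len_via_cpu == 4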
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices //= self.page_size
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
 
         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
             )
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
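After this refactor, `normal_decode_set_medadata` updates the pre-allocated metadata tensors in place through `copy_` instead of rebinding attributes on the metadata object, and the new `seq_len_delta` argument lets the same helper serve normal decode (delta 0) and draft decode (delta `self.speculative_step_id + 1`), as seen in the call sites above. A toy, self-contained sketch of the in-place update pattern (buffer names and sizes are illustrative, not the real sglang buffers):

    import torch

    # Toy buffers standing in for pre-allocated, graph-captured metadata tensors.
    cache_seqlens_int32 = torch.zeros(3, dtype=torch.int32)
    cu_seqlens_k = torch.zeros(4, dtype=torch.int32)
    storage_before = cache_seqlens_int32.data_ptr()

    seq_lens = torch.tensor([4, 7, 5])
    seq_len_delta = 1  # e.g. speculative_step_id + 1 in the draft-decode path

    # In-place copy_ keeps the tensors' storage fixed, which is what buffers
    # captured by a CUDA graph generally require; rebinding an attribute to a
    # freshly created tensor would leave captured kernels reading stale data.
    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))

    assert cache_seqlens_int32.data_ptr() == storage_before
    assert cu_seqlens_k.tolist() == [0, 5, 13, 19]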
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale
 
-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")
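The removed `with self.device as device:` wrapper (presumably a `torch.device` used as a context manager) only changes the default device for tensors created inside the block; since the `plan.default` call above passes already-allocated workspace buffers and CPU index arrays, the wrapper appears to have had no effect here. A standalone illustration of what such a context manager does, independent of this code:

    import torch

    # torch.device as a context manager (PyTorch 2.x) sets the default device
    # for tensor *creation* inside the block; it does not move existing tensors.
    existing = torch.zeros(2)           # created outside the block, on CPU
    with torch.device("cpu"):
        created_inside = torch.ones(2)  # picks up the context's default device

    assert existing.device.type == "cpu"
    assert created_inside.device.type == "cpu"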