sglang / Commits / 2dae104d

Commit 2dae104d (unverified)
Minor cleanup of fa3 backend (#6999)
Authored Jun 10, 2025 by Lianmin Zheng; committed by GitHub on Jun 10, 2025
Parent: cef6655b

Showing 2 changed files with 63 additions and 64 deletions:
- python/sglang/srt/layers/attention/flashattention_backend.py (+48, -48)
- python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+15, -16)
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
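Why the trailing .to(torch.int32) casts can go: torch.Tensor.copy_ converts dtype (and device) from the source to the destination implicitly, so copying a default-int64 length tensor into the preallocated int32 buffer needs no explicit cast. A minimal sketch of that behavior (buffer sizes and values are made up, not taken from the diff):

import torch

cache_seqlens_int32 = torch.zeros(4, dtype=torch.int32)  # preallocated int32 buffer
seq_lens = torch.tensor([7, 12, 3, 9])                   # int64 by default
speculative_num_draft_tokens = 4                         # hypothetical value

# copy_ casts the int64 source into the int32 destination in place
cache_seqlens_int32.copy_(seq_lens + speculative_num_draft_tokens)
print(cache_seqlens_int32.dtype)  # torch.int32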
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
         if spec_info is not None:
             # Draft Decode
             if self.topk <= 1:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # When topk = 1, we use the normal decode metadata
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    )
-                )
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][
-                        :max_seq_pages
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+
+                normal_decode_set_medadata(
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.speculative_step_id + 1,
+                    self.page_size,
+                )
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.draft_decode_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
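For orientation, the rewritten topk <= 1 branch keeps only the scalar bookkeeping inline and delegates the buffer updates to normal_decode_set_medadata, passing self.speculative_step_id + 1 as the extra per-request length. A rough sketch of that scalar arithmetic, with invented sizes (the page size and step id below are assumptions, not values from the diff):

import torch

page_size = 16
speculative_step_id = 2                     # hypothetical draft step
seq_lens_cpu = torch.tensor([100, 37, 64])  # toy per-request KV lengths

# each request also holds the draft tokens for steps 0..speculative_step_id,
# so the effective KV length grows by speculative_step_id + 1
max_len = seq_lens_cpu.max().item()
max_seq_len_k = max_len + speculative_step_id + 1

# ceil-division gives the number of KV-cache pages the page table must cover
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size
print(max_seq_len_k, max_seq_pages)  # 103 7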
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                     self.speculative_num_steps, -1
                 ).T.contiguous()
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                    cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    cache_loc[:, :decode_length]
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len

             normal_decode_set_medadata(
-                metadata,
+                metadata.cache_seqlens_int32,
+                metadata.cu_seqlens_k,
+                metadata.page_table,
                 self.req_to_token,
                 req_pool_indices,
                 self.decode_cuda_graph_metadata["strided_indices"],
                 max_seq_pages,
                 seq_lens,
+                0,
                 self.page_size,
             )
@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )
                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
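The cu_seqlens_* buffers touched here follow the usual FlashAttention layout: length bs + 1, element 0 fixed at zero, and the rest an inclusive prefix sum of the per-request lengths, which is why only the [1:] slice is overwritten. A small sketch of that convention with toy values:

import torch

bs = 3
cache_seqlens_int32 = torch.tensor([5, 2, 8], dtype=torch.int32)

# preallocated (bs + 1)-element buffer; index 0 stays 0
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
print(cu_seqlens_k)  # tensor([ 0,  5,  7, 15], dtype=torch.int32)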
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
                 )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
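A plausible reason for reading max_seq_len_q from spec_info.accept_length_cpu rather than accept_length.max().item(): calling .item() on a CUDA tensor forces a device-to-host synchronization, while taking max() of a host-side sequence does not. A sketch of the pattern with invented values (the +1 offset and the accept_length_cpu field come from the diff; everything else is illustrative):

import torch

accept_length_cpu = [3, 1, 2]                    # host-side copy of accepted-token counts
accept_length = torch.tensor(accept_length_cpu)  # device tensor in the real code

# no GPU sync: the maximum comes from the host-side list
max_seq_len_q = max(accept_length_cpu) + 1

cu_seqlens_q = torch.zeros(len(accept_length_cpu) + 1, dtype=torch.int32)
cu_seqlens_q[1:].copy_(torch.cumsum(accept_length, dim=0, dtype=torch.int32))
print(max_seq_len_q, cu_seqlens_q)  # 4 tensor([0, 3, 4, 6], dtype=torch.int32)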
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices //= self.page_size
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
...
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
# TODO: incrementally update the metadata for the later steps,
# so that they do not need to recompute everything from scratch.
self
.
attn_backends
[
i
].
init_forward_metadata_replay_cuda_graph
(
self
.
attn_backends
[
i
].
init_forward_metadata_replay_cuda_graph
(
bs
,
bs
,
forward_batch
.
req_pool_indices
,
forward_batch
.
req_pool_indices
,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
         )


-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
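To see the refactored helper end to end, here is a self-contained toy run of the same logic on CPU tensors; the function body mirrors the new version above, while the driver shapes and values are invented (the real buffers are GPU tensors reused across CUDA graph replays):

import torch

def normal_decode_set_medadata(
    cache_seqlens_int32, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, seq_len_delta, page_size,
):
    # same buffer updates as the refactored function in the diff
    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
    page_indices = req_to_token[
        req_pool_indices[:, None],
        strided_indices[:max_seq_pages][None, :],
    ]
    page_table[:, :max_seq_pages].copy_(page_indices // page_size)

bs, page_size, max_pages = 2, 4, 3
req_to_token = torch.arange(4 * 12).reshape(4, 12)  # toy request-pool -> token-slot map
req_pool_indices = torch.tensor([1, 3])             # pool rows owned by this batch
strided_indices = torch.arange(0, 12, page_size)    # first token slot of every page
seq_lens = torch.tensor([5, 9])

cache_seqlens_int32 = torch.zeros(bs, dtype=torch.int32)
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
page_table = torch.zeros(bs, max_pages, dtype=torch.int64)

normal_decode_set_medadata(
    cache_seqlens_int32, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_pages, seq_lens, 0, page_size,
)
print(cache_seqlens_int32)  # tensor([5, 9], dtype=torch.int32)
print(cu_seqlens_k)         # tensor([ 0,  5, 14], dtype=torch.int32)
print(page_table)           # physical page ids per request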
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale

-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")