sglang · Commit b91cb67e (unverified signature)

[Performance] Qwen3-Next: replace arange to cached query_start_loc_li… (#10553)

Authored Sep 18, 2025 by Binyao Jiang; committed via GitHub on Sep 18, 2025. Parent: e7bc6003.

1 changed file with 44 additions and 31 deletions:
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py (+44, -31)
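The change itself is small: instead of materializing a fresh `torch.arange` for `query_start_loc` on every decode or target-verify step, the Mamba attention backend builds the largest tensors it will ever need once, in `init_cuda_graph_state`, and serves each step a prefix slice. A minimal standalone sketch of that pattern (the names `max_bs` and `query_start_loc_for` are illustrative, not the backend's API):

```python
import torch

# Illustrative sketch of the caching pattern this commit adopts; the real
# backend stores the cached tensor on the attention-backend instance.
max_bs = 8  # hypothetical CUDA-graph capture limit

# Allocate once, at initialization time.
cached_decode_query_start_loc = torch.arange(0, max_bs + 1, dtype=torch.int32)

def query_start_loc_for(bs: int) -> torch.Tensor:
    # Slicing the cached tensor returns a view; no allocation per step.
    return cached_decode_query_start_loc[: bs + 1]

# For any bs <= max_bs the slice matches a fresh arange.
for bs in range(1, max_bs + 1):
    assert torch.equal(
        query_start_loc_for(bs), torch.arange(0, bs + 1, dtype=torch.int32)
    )
```

Slicing returns a view over the cached storage, so the steady-state cost is pointer arithmetic rather than an allocation plus a kernel launch per step.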
```diff
@@ -61,18 +61,15 @@ class MambaAttnBackend(AttentionBackend):
         self.forward_metadata: ForwardMetadata = None
         self.state_indices_list = []
         self.query_start_loc_list = []
-
-    @classmethod
-    @lru_cache(maxsize=128)
-    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
-        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
-        device = torch.device(device_str)
-        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            query_start_loc = self._get_cached_arange(bs, str(self.device))
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
         elif forward_batch.forward_mode.is_extend():
             if forward_batch.forward_mode.is_target_verify():
                 query_start_loc = torch.arange(
```
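This hunk also drops the `lru_cache`-wrapped `_get_cached_arange` classmethod: a process-wide cache keyed on `(bs, device_str)` pays a hash-and-lookup on every call and pins one tensor per batch size for the life of the process, while the non-graph decode path is cheap enough to simply allocate. An illustrative micro-benchmark of the underlying trade-off (not from the commit; CPU-only here, and timings vary by machine):

```python
import time
import torch

# Compare allocating a fresh arange per step against slicing a
# preallocated one. On GPU the gap also includes launch overhead.
max_bs, iters = 256, 10_000
cached = torch.arange(0, max_bs + 1, dtype=torch.int32)

t0 = time.perf_counter()
for _ in range(iters):
    torch.arange(0, 129, dtype=torch.int32)  # fresh allocation each call
t1 = time.perf_counter()
for _ in range(iters):
    cached[:129]  # view into the cached tensor
t2 = time.perf_counter()

print(f"arange per call: {t1 - t0:.4f}s, cached slice: {t2 - t1:.4f}s")
```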
```diff
@@ -102,6 +99,10 @@ class MambaAttnBackend(AttentionBackend):
             )

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
         for i in range(max_bs):
             self.state_indices_list.append(
                 torch.full(
```
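The new assert guarantees that `verify_step = max_num_tokens / max_bs` is a whole number, so the strided arange built from it lands on exact token offsets even though it is a Python float. A worked example with hypothetical capture sizes:

```python
import torch

# Hypothetical capture sizes, chosen for illustration only.
max_bs, max_num_tokens = 4, 16
assert max_num_tokens % max_bs == 0, "required so verify_step is a whole number"
verify_step = max_num_tokens / max_bs  # 4.0 (a float, but exact)

verify_qsl = torch.arange(
    0, max_bs * verify_step + 1, step=verify_step, dtype=torch.int32
)
print(verify_qsl.tolist())  # [0, 4, 8, 12, 16]
```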
```diff
@@ -111,6 +112,16 @@ class MambaAttnBackend(AttentionBackend):
             self.query_start_loc_list.append(
                 torch.empty((i + 2,), dtype=torch.int32, device=self.device)
             )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )

     def init_forward_metadata_capture_cuda_graph(
         self,
```
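A quick equivalence check (illustrative; it assumes `spec_info.draft_token_num` equals the `verify_step` used at capture time, which holds when verify graphs are captured with `max_num_tokens = max_bs * draft_token_num`): prefixes of the two cached tensors reproduce exactly what the replaced per-call `torch.arange` computed.

```python
import torch

max_bs, max_num_tokens = 4, 16  # hypothetical capture sizes
verify_step = max_num_tokens / max_bs
draft_token_num = int(verify_step)  # assumed equal to verify_step

decode_cache = torch.arange(0, max_bs + 1, dtype=torch.int32)
verify_cache = torch.arange(
    0, max_bs * verify_step + 1, step=verify_step, dtype=torch.int32
)

for bs in range(1, max_bs + 1):
    # Decode: a prefix of the cache is exactly arange(0, bs + 1).
    assert torch.equal(
        decode_cache[: bs + 1], torch.arange(0, bs + 1, dtype=torch.int32)
    )
    # Verify: same, with stride draft_token_num.
    assert torch.equal(
        verify_cache[: bs + 1],
        torch.arange(
            0, bs * draft_token_num + 1, step=draft_token_num, dtype=torch.int32
        ),
    )
print("cached prefixes match the per-call arange they replace")
```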
```diff
@@ -123,16 +134,12 @@ class MambaAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
         elif forward_mode.is_target_verify():
             self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
```
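In this capture path the destination buffers matter as much as the values: a CUDA graph records fixed device addresses, so replays only observe updates written in place into the preallocated `query_start_loc_list[bs - 1]` buffers. That is why both branches use `copy_` rather than rebinding a variable to a fresh tensor. A small sketch of the distinction (CPU tensors for portability; the address check is illustrative):

```python
import torch

# Stand-in for one preallocated query_start_loc_list buffer.
buf = torch.empty(5, dtype=torch.int32)
cached = torch.arange(0, 5, dtype=torch.int32)

ptr = buf.data_ptr()
buf.copy_(cached)             # in-place: contents change, address is stable,
assert buf.data_ptr() == ptr  # so a captured graph keeps reading this buffer

buf = torch.arange(0, 5, dtype=torch.int32)  # rebinding allocates new storage
assert buf.data_ptr() != ptr                 # a captured graph would never see it
```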
```diff
@@ -163,23 +170,29 @@ class MambaAttnBackend(AttentionBackend):
             mamba_indices[bs - num_padding :] = -1
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    bs - num_padding
+                )
         elif forward_mode.is_target_verify():
-            self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-            )
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
-                    bs - num_padding
-                ) * spec_info.draft_token_num
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")

```
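The padded replay branches write the real prefix from the cache and then pin every padded entry to the last real offset, so padded slots become zero-length queries. A worked example with hypothetical replay shapes (bs=4, num_padding=1, draft_token_num=4; plain slice assignment stands in for the backend's scalar `copy_`):

```python
import torch

bs, num_padding, draft = 4, 1, 4  # hypothetical replay shapes
decode_cache = torch.arange(0, bs + 1, dtype=torch.int32)
verify_cache = torch.arange(0, bs * draft + 1, step=draft, dtype=torch.int32)

# Decode path: the repeated final offset gives the padded slot length 0.
qsl = torch.empty(bs + 1, dtype=torch.int32)
qsl[: bs - num_padding].copy_(decode_cache[: bs - num_padding])
qsl[bs - num_padding :] = bs - num_padding  # backend writes this via copy_
print(qsl.tolist())  # [0, 1, 2, 3, 3]

# Verify path: the padded sequence spans no draft tokens.
qsl_v = torch.empty(bs + 1, dtype=torch.int32)
qsl_v[: bs - num_padding].copy_(verify_cache[: bs - num_padding])
qsl_v[bs - num_padding :] = (bs - num_padding) * draft
print(qsl_v.tolist())  # [0, 4, 8, 12, 12]
```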