Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
25e7dbe8
Unverified
Commit
25e7dbe8
authored
Oct 02, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 02, 2025
Browse files
Fix ngram spec with page size > 1 (#11135)
parent
0b2aa8a7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
32 additions
and
10 deletions
+32
-10
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+4
-1
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+2
-0
python/sglang/srt/speculative/ngram_utils.py
python/sglang/srt/speculative/ngram_utils.py
+8
-1
test/srt/test_ngram_speculative_decoding.py
test/srt/test_ngram_speculative_decoding.py
+15
-5
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
25e7dbe8
...
@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
self
.
device
,
non_blocking
=
True
self
.
device
,
non_blocking
=
True
)
)
seq_lens_cpu
_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
)
seq_lens_cpu
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
)
orig_seq_lens_tensor
=
torch
.
tensor
(
orig_seq_lens
,
dtype
=
torch
.
int32
).
to
(
orig_seq_lens_tensor
=
torch
.
tensor
(
orig_seq_lens
,
dtype
=
torch
.
int32
).
to
(
self
.
device
,
non_blocking
=
True
self
.
device
,
non_blocking
=
True
)
)
...
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor
,
prefix_lens_tensor
,
prefix_lens_cpu_tensor
,
prefix_lens_cpu_tensor
,
seq_lens_tensor
,
seq_lens_tensor
,
seq_lens_cpu
_tensor
,
seq_lens_cpu
,
last_loc
,
last_loc
,
extend_num_tokens
,
extend_num_tokens
,
)
)
...
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
input_ids
=
input_ids_tensor
self
.
input_ids
=
input_ids_tensor
self
.
req_pool_indices
=
req_pool_indices_tensor
self
.
req_pool_indices
=
req_pool_indices_tensor
self
.
seq_lens
=
seq_lens_tensor
self
.
seq_lens
=
seq_lens_tensor
self
.
seq_lens_cpu
=
seq_lens_cpu
_tensor
self
.
seq_lens_cpu
=
seq_lens_cpu
self
.
orig_seq_lens
=
orig_seq_lens_tensor
self
.
orig_seq_lens
=
orig_seq_lens_tensor
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
self
.
input_embeds
=
(
self
.
input_embeds
=
(
...
...
python/sglang/srt/server_args.py
View file @
25e7dbe8
...
@@ -1087,7 +1087,10 @@ class ServerArgs:
...
@@ -1087,7 +1087,10 @@ class ServerArgs:
and
self
.
attention_backend
!=
"flashinfer"
and
self
.
attention_backend
!=
"flashinfer"
):
):
raise
ValueError
(
raise
ValueError
(
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
f
"speculative_eagle_topk(
{
self
.
speculative_eagle_topk
}
) > 1 "
f
"with page_size(
{
self
.
page_size
}
) > 1 is unstable "
"and produces incorrect results for paged attention backends. "
"This combination is only supported for the 'flashinfer' backend."
)
)
if
self
.
enable_dp_attention
:
if
self
.
enable_dp_attention
:
# TODO: support dp attention for ngram speculative decoding
# TODO: support dp attention for ngram speculative decoding
...
...
python/sglang/srt/speculative/eagle_info.py
View file @
25e7dbe8
...
@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
...
@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
evict_mask
=
torch
.
full_like
(
self
.
draft_token
,
True
,
dtype
=
torch
.
bool
)
evict_mask
=
torch
.
full_like
(
self
.
draft_token
,
True
,
dtype
=
torch
.
bool
)
evict_mask
[
accept_index
]
=
False
evict_mask
[
accept_index
]
=
False
accept_length_cpu
=
accept_length
.
cpu
()
accept_length_cpu
=
accept_length
.
cpu
()
# FIXME: this `tolist()` fixes the numerical calculation consistency
# try to unify the tensor representation and list representation
accept_length_list
=
accept_length_cpu
.
tolist
()
accept_length_list
=
accept_length_cpu
.
tolist
()
if
page_size
==
1
:
if
page_size
==
1
:
...
...
python/sglang/srt/speculative/ngram_utils.py
View file @
25e7dbe8
...
@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
...
@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
else
:
else
:
# TODO(lsyin): add prefix lens cpu here to support page size > 1
# TODO(lsyin): add prefix lens cpu here to support page size > 1
prefix_lens
=
batch
.
seq_lens
prefix_lens
=
batch
.
seq_lens
prefix_lens_cpu
=
batch
.
seq_lens_cpu
end_offset
=
prefix_lens
+
self
.
draft_token_num
end_offset
=
prefix_lens
+
self
.
draft_token_num
end_offset_cpu
=
prefix_lens_cpu
+
self
.
draft_token_num
last_loc
=
get_last_loc
(
last_loc
=
get_last_loc
(
batch
.
req_to_token_pool
.
req_to_token
,
batch
.
req_to_token_pool
.
req_to_token
,
batch
.
req_pool_indices
,
batch
.
req_pool_indices
,
prefix_lens
,
prefix_lens
,
)
)
batch
.
out_cache_loc
=
batch
.
alloc_paged_token_slots_extend
(
batch
.
out_cache_loc
=
batch
.
alloc_paged_token_slots_extend
(
prefix_lens
,
end_offset
,
last_loc
,
len
(
batch
.
input_ids
)
prefix_lens
,
prefix_lens_cpu
,
end_offset
,
end_offset_cpu
,
last_loc
,
len
(
batch
.
input_ids
),
)
)
self
.
last_loc
=
last_loc
self
.
last_loc
=
last_loc
...
...
test/srt/test_ngram_speculative_decoding.py
View file @
25e7dbe8
...
@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
...
@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
]
]
class
Test
Standalone
SpeculativeDecodingBase
(
CustomTestCase
):
class
Test
Ngram
SpeculativeDecodingBase
(
CustomTestCase
):
model
=
DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
model
=
DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
base_url
=
DEFAULT_URL_FOR_TEST
base_url
=
DEFAULT_URL_FOR_TEST
...
@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
...
@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
self
.
assertGreater
(
avg_spec_accept_length
,
self
.
spec_decode_threshold
)
self
.
assertGreater
(
avg_spec_accept_length
,
self
.
spec_decode_threshold
)
class
Test
Standalone
SpeculativeDecodingTriton
(
Test
Standalone
SpeculativeDecodingBase
):
class
Test
Ngram
SpeculativeDecodingTriton
(
Test
Ngram
SpeculativeDecodingBase
):
@
classmethod
@
classmethod
def
get_server_args
(
cls
):
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--attention-backend"
,
"triton"
]
return
DEFAULT_SERVER_ARGS
+
[
"--attention-backend"
,
"triton"
]
class
TestStandaloneSpeculativeDecodingFlashinfer
(
class
TestNgramSpeculativeDecodingFlashinfer
(
TestNgramSpeculativeDecodingBase
):
TestStandaloneSpeculativeDecodingBase
):
@
classmethod
@
classmethod
def
get_server_args
(
cls
):
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--attention-backend"
,
"flashinfer"
]
return
DEFAULT_SERVER_ARGS
+
[
"--attention-backend"
,
"flashinfer"
]
class
TestNgramSpeculativeDecodingPaged
(
TestNgramSpeculativeDecodingBase
):
@
classmethod
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--attention-backend"
,
"flashinfer"
,
"--page-size"
,
"64"
,
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment