Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6a9d6ca3
Unverified
Commit
6a9d6ca3
authored
Aug 17, 2025
by
zyksir
Committed by
GitHub
Aug 16, 2025
Browse files
fix unexcepted answer in EAGLE mode (#9252)
parent
94371dbb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
5 deletions
+33
-5
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+18
-5
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+15
-0
No files found.
python/sglang/srt/speculative/eagle_utils.py
View file @
6a9d6ca3
...
@@ -177,11 +177,24 @@ class EagleDraftInput:
...
@@ -177,11 +177,24 @@ class EagleDraftInput:
)
)
return
kv_indices
,
cum_kv_seq_len
,
qo_indptr
,
None
return
kv_indices
,
cum_kv_seq_len
,
qo_indptr
,
None
def
filter_batch
(
self
,
new_indices
:
torch
.
Tensor
):
def
filter_batch
(
self
,
new_indices
:
torch
.
Tensor
,
has_been_filtered
:
bool
=
True
):
if
has_been_filtered
:
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
if
len
(
new_indices
)
!=
len
(
self
.
topk_p
):
logger
.
warning
(
f
"length of new_indices:
{
len
(
new_indices
)
}
!= length of topk_p:
{
len
(
self
.
topk_p
)
}
, this should not happen"
)
self
.
topk_p
=
self
.
topk_p
[:
len
(
new_indices
)]
self
.
topk_p
=
self
.
topk_p
[:
len
(
new_indices
)]
self
.
topk_index
=
self
.
topk_index
[:
len
(
new_indices
)]
self
.
topk_index
=
self
.
topk_index
[:
len
(
new_indices
)]
self
.
hidden_states
=
self
.
hidden_states
[:
len
(
new_indices
)]
self
.
hidden_states
=
self
.
hidden_states
[:
len
(
new_indices
)]
self
.
verified_id
=
self
.
verified_id
[:
len
(
new_indices
)]
self
.
verified_id
=
self
.
verified_id
[:
len
(
new_indices
)]
else
:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self
.
topk_p
=
self
.
topk_p
[
new_indices
]
self
.
topk_index
=
self
.
topk_index
[
new_indices
]
self
.
hidden_states
=
self
.
hidden_states
[
new_indices
]
self
.
verified_id
=
self
.
verified_id
[
new_indices
]
def
merge_batch
(
self
,
spec_info
:
EagleDraftInput
):
def
merge_batch
(
self
,
spec_info
:
EagleDraftInput
):
if
self
.
hidden_states
is
None
:
if
self
.
hidden_states
is
None
:
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
6a9d6ca3
...
@@ -836,6 +836,21 @@ class EAGLEWorker(TpModelWorker):
...
@@ -836,6 +836,21 @@ class EAGLEWorker(TpModelWorker):
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
assert
forward_batch
.
spec_info
is
batch
.
spec_info
assert
forward_batch
.
spec_info
is
batch
.
spec_info
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
has_finished
,
unfinished_req_index
=
False
,
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
finished
():
has_finished
=
True
else
:
unfinished_req_index
.
append
(
i
)
if
has_finished
:
unfinished_index_device
=
torch
.
tensor
(
unfinished_req_index
,
dtype
=
torch
.
int64
,
device
=
batch
.
spec_info
.
topk_p
.
device
,
)
batch
.
spec_info
.
filter_batch
(
unfinished_index_device
,
has_been_filtered
=
False
)
def
forward_draft_extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
def
forward_draft_extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
assert
isinstance
(
batch
.
spec_info
,
EagleDraftInput
)
assert
isinstance
(
batch
.
spec_info
,
EagleDraftInput
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment