Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9da1095d
Unverified
Commit
9da1095d
authored
May 18, 2025
by
wwl2755
Committed by
GitHub
May 18, 2025
Browse files
[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)
Signed-off-by:
wwl2755
<
wangwenlong2755@gmail.com
>
parent
d1211f87
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
3 deletions
+21
-3
tests/spec_decode/e2e/test_eagle_correctness.py
tests/spec_decode/e2e/test_eagle_correctness.py
+0
-2
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+11
-0
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+8
-1
vllm/sequence.py
vllm/sequence.py
+2
-0
No files found.
tests/spec_decode/e2e/test_eagle_correctness.py
View file @
9da1095d
...
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
...
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
batch_size
,
output_len
,
seed
)
batch_size
,
output_len
,
seed
)
# TRACKING: https://github.com/vllm-project/vllm/issues/18166
@
pytest
.
mark
.
skip
(
reason
=
"RE-ENABLE: Failing on main."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
vllm/model_executor/models/eagle.py
View file @
9da1095d
...
@@ -146,6 +146,17 @@ class EAGLE(nn.Module):
...
@@ -146,6 +146,17 @@ class EAGLE(nn.Module):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
# Handle both empty previous_hidden_states
# and mismatched batch size
batch_size
=
inputs_embeds
.
size
(
0
)
if
previous_hidden_states
.
size
(
0
)
==
0
or
\
previous_hidden_states
.
size
(
0
)
!=
batch_size
:
hidden_dim
=
self
.
config
.
model
.
hidden_size
device
=
inputs_embeds
.
device
# Create zero tensor with matching batch size
previous_hidden_states
=
\
torch
.
zeros
(
batch_size
,
hidden_dim
,
device
=
device
)
if
self
.
add_para_norm
:
if
self
.
add_para_norm
:
inputs_embeds
=
torch
.
cat
([
inputs_embeds
=
torch
.
cat
([
self
.
enorm
(
inputs_embeds
),
self
.
enorm
(
inputs_embeds
),
...
...
vllm/model_executor/models/medusa.py
View file @
9da1095d
...
@@ -164,7 +164,14 @@ class Medusa(nn.Module):
...
@@ -164,7 +164,14 @@ class Medusa(nn.Module):
self
,
self
,
previous_hidden_states
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
list
[
SamplerOutput
]:
)
->
Optional
[
list
[
SamplerOutput
]]:
# During preemption, we may receive an empty tensor (batch_size=0)
if
previous_hidden_states
.
size
(
0
)
==
0
:
# Return None to signal the Top1Proposer that no proposals
# were generated for this batch, allowing it to handle this
# special case appropriately
return
None
return
self
.
sample
(
return
self
.
sample
(
logits
=
self
.
compute_logits
(
logits
=
self
.
compute_logits
(
hidden_states
=
self
.
forward
(
previous_hidden_states
),
hidden_states
=
self
.
forward
(
previous_hidden_states
),
...
...
vllm/sequence.py
View file @
9da1095d
...
@@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
# may be "paused" then "resumed" later. This should only prune sequences
# may be "paused" then "resumed" later. This should only prune sequences
# which are confirmed to be aborted.
# which are confirmed to be aborted.
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
# Only keep sequence IDs that exist in self._seq_ids
seq_ids
=
[
seq_id
for
seq_id
in
seq_ids
if
seq_id
in
self
.
_seq_ids
]
if
seq_ids
!=
self
.
_seq_ids
:
if
seq_ids
!=
self
.
_seq_ids
:
# Batch contents changed - prune removed sequences.
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
index
=
[
self
.
_seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment