Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c68698b3
Unverified
Commit
c68698b3
authored
Jun 12, 2025
by
qizixi
Committed by
GitHub
Jun 12, 2025
Browse files
[Bugfix] Fix EAGLE vocab embedding for multimodal target model (#19570)
Signed-off-by:
qizixi
<
qizixi@meta.com
>
parent
e3b12667
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
7 deletions
+12
-7
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+12
-7
No files found.
vllm/v1/spec_decode/eagle.py
View file @
c68698b3
...
...
@@ -329,16 +329,24 @@ class EagleProposer:
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
if
supports_multimodal
(
target_model
):
# handle multimodality
self
.
model
.
config
.
image_token_index
=
(
target_model
.
config
.
image_token_index
)
target_language_model
=
target_model
.
get_language_model
()
else
:
target_language_model
=
target_model
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
\
and
self
.
model
.
model
.
embed_tokens
.
weight
.
shape
\
==
target_model
.
model
.
embed_tokens
.
weight
.
shape
:
==
target_
language_
model
.
model
.
embed_tokens
.
weight
.
shape
:
logger
.
info
(
"Assuming the EAGLE head shares the same vocab embedding"
\
" with the target model."
)
del
self
.
model
.
model
.
embed_tokens
self
.
model
.
model
.
embed_tokens
=
target_model
.
model
.
embed_tokens
self
.
model
.
model
.
embed_tokens
=
(
target_language_model
.
model
.
embed_tokens
)
else
:
logger
.
info
(
"The EAGLE head's vocab embedding will be loaded separately"
\
...
...
@@ -349,12 +357,9 @@ class EagleProposer:
# some model definition do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if
self
.
vllm_config
.
speculative_config
.
method
!=
"eagle3"
and
\
hasattr
(
target_model
,
"lm_head"
):
hasattr
(
target_
language_
model
,
"lm_head"
):
logger
.
info
(
"Loading EAGLE LM head weights from the target model."
)
if
supports_multimodal
(
target_model
):
self
.
model
.
lm_head
=
target_model
.
get_language_model
().
lm_head
else
:
self
.
model
.
lm_head
=
target_model
.
lm_head
self
.
model
.
lm_head
=
target_language_model
.
lm_head
@
torch
.
inference_mode
()
def
dummy_run
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment