Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a04720bc
Unverified
Commit
a04720bc
authored
May 22, 2025
by
Ekagra Ranjan
Committed by
GitHub
May 22, 2025
Browse files
[V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (#18290)
parent
7b9d832c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
3 deletions
+9
-3
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+4
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+5
-1
No files found.
vllm/transformers_utils/configs/eagle.py
View file @
a04720bc
...
@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
...
@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
assert
self
.
model
is
not
None
,
\
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle"
"model should not be None when method is eagle"
kwargs
[
"architectures"
]
=
[
kwargs
[
"architectures"
]
=
[
f
"Eagle
{
arch
}
"
for
arch
in
self
.
model
.
architectures
f
"Eagle
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
]
elif
method
==
"eagle3"
:
elif
method
==
"eagle3"
:
assert
self
.
model
is
not
None
,
\
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle3"
"model should not be None when method is eagle3"
kwargs
[
"architectures"
]
=
[
kwargs
[
"architectures"
]
=
[
f
"Eagle3
{
arch
}
"
for
arch
in
self
.
model
.
architectures
f
"Eagle3
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle3"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
]
else
:
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
...
...
vllm/v1/spec_decode/eagle.py
View file @
a04720bc
...
@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import get_pp_group
...
@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
(
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -308,6 +309,9 @@ class EagleProposer:
...
@@ -308,6 +309,9 @@ class EagleProposer:
loaded_weights
=
self
.
model
.
load_weights
(
loaded_weights
=
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
process_weights_after_loading
(
self
.
model
,
draft_model_config
,
target_device
)
# share embed_tokens with the target model if needed
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
:
if
get_pp_group
().
world_size
==
1
:
assert
"model.embed_tokens.weight"
not
in
loaded_weights
,
\
assert
"model.embed_tokens.weight"
not
in
loaded_weights
,
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment