Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e1360005
Unverified
Commit
e1360005
authored
Apr 28, 2025
by
Ekagra Ranjan
Committed by
GitHub
Apr 29, 2025
Browse files
[V1][Spec Decode] Make Eagle model arch config driven (#17323)
parent
86d9fc29
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
13 deletions
+26
-13
vllm/config.py
vllm/config.py
+2
-1
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+18
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+6
-11
No files found.
vllm/config.py
View file @
e1360005
...
...
@@ -2401,7 +2401,8 @@ class SpeculativeConfig:
pass
else
:
eagle_config
=
EAGLEConfig
(
self
.
draft_model_config
.
hf_config
)
self
.
draft_model_config
.
hf_config
,
method
=
self
.
method
)
self
.
draft_model_config
.
hf_config
=
eagle_config
if
(
self
.
num_speculative_tokens
is
not
None
...
...
vllm/transformers_utils/configs/eagle.py
View file @
e1360005
...
...
@@ -15,6 +15,7 @@ class EAGLEConfig(PretrainedConfig):
def
__init__
(
self
,
model
:
Union
[
PretrainedConfig
,
dict
,
None
]
=
None
,
truncated_vocab_size
:
Optional
[
int
]
=
None
,
method
:
Optional
[
str
]
=
'eagle'
,
**
kwargs
):
model_config
:
Union
[
PretrainedConfig
,
DeepseekV2Config
,
None
]
...
...
@@ -45,7 +46,23 @@ class EAGLEConfig(PretrainedConfig):
if
not
envs
.
VLLM_USE_V1
:
kwargs
[
"architectures"
]
=
[
"EAGLEModel"
]
else
:
kwargs
[
"architectures"
]
=
[
"EagleLlamaForCausalLM"
]
# Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM
if
method
==
"eagle"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle"
kwargs
[
"architectures"
]
=
[
f
"Eagle
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
elif
method
==
"eagle3"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle3"
kwargs
[
"architectures"
]
=
[
f
"Eagle3
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
Supported methods are eagle and eagle3."
)
super
().
__init__
(
**
kwargs
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
e1360005
...
...
@@ -9,8 +9,7 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.llama_eagle
import
EagleLlamaForCausalLM
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -225,13 +224,9 @@ class EagleProposer:
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
self
.
vllm_config
):
if
self
.
vllm_config
.
speculative_config
.
method
==
"eagle"
:
self
.
model
=
EagleLlamaForCausalLM
(
model_config
=
draft_model_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
else
:
assert
self
.
vllm_config
.
speculative_config
.
method
==
"eagle3"
self
.
model
=
Eagle3LlamaForCausalLM
(
draft_model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
draft_model_config
.
architectures
)
self
.
model
=
draft_model_cls
(
model_config
=
draft_model_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment