Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dfbc1f88
Unverified
Commit
dfbc1f88
authored
Aug 01, 2025
by
Dipika Sikka
Committed by
GitHub
Aug 01, 2025
Browse files
[Speculative Decoding] Add `speculators` config support (#21345)
parent
87c94bc8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
232 additions
and
11 deletions
+232
-11
tests/speculative_decoding/speculators/test_eagle3.py
tests/speculative_decoding/speculators/test_eagle3.py
+16
-0
vllm/config.py
vllm/config.py
+16
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+21
-1
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+23
-3
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+29
-3
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/speculators/__init__.py
vllm/transformers_utils/configs/speculators/__init__.py
+2
-0
vllm/transformers_utils/configs/speculators/algos.py
vllm/transformers_utils/configs/speculators/algos.py
+32
-0
vllm/transformers_utils/configs/speculators/base.py
vllm/transformers_utils/configs/speculators/base.py
+91
-0
No files found.
tests/speculative_decoding/speculators/test_eagle3.py
0 → 100644
View file @
dfbc1f88
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"
),
(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"
)])
def
test_llama
(
vllm_runner
,
example_prompts
,
model_path
):
with
vllm_runner
(
model_path
,
dtype
=
torch
.
bfloat16
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
=
20
)
print
(
vllm_outputs
)
assert
vllm_outputs
vllm/config.py
View file @
dfbc1f88
...
@@ -39,8 +39,8 @@ from vllm.transformers_utils.config import (
...
@@ -39,8 +39,8 @@ from vllm.transformers_utils.config import (
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
,
get_pooling_config
,
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
try_get_generation_config
,
try_get_safetensors_metadata
,
maybe_override_with_speculators_target_model
,
try_get_generation_config
,
try_get_tokenizer_config
,
uses_mrope
)
try_get_safetensors_metadata
,
try_get_tokenizer_config
,
uses_mrope
)
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -535,6 +535,15 @@ class ModelConfig:
...
@@ -535,6 +535,15 @@ class ModelConfig:
"affect the random state of the Python process that "
"affect the random state of the Python process that "
"launched vLLM."
,
self
.
seed
)
"launched vLLM."
,
self
.
seed
)
if
self
.
runner
!=
"draft"
:
# If we're not running the draft model, check for speculators config
# If speculators config, set model / tokenizer to be target model
self
.
model
,
self
.
tokenizer
=
maybe_override_with_speculators_target_model
(
# noqa: E501
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
revision
=
self
.
revision
,
trust_remote_code
=
self
.
trust_remote_code
)
# Keep set served_model_name before maybe_model_redirect(self.model)
# Keep set served_model_name before maybe_model_redirect(self.model)
self
.
served_model_name
=
get_served_model_name
(
self
.
model
,
self
.
served_model_name
=
get_served_model_name
(
self
.
model
,
self
.
served_model_name
)
self
.
served_model_name
)
...
@@ -606,8 +615,8 @@ class ModelConfig:
...
@@ -606,8 +615,8 @@ class ModelConfig:
self
.
config_format
,
self
.
config_format
,
hf_overrides_kw
=
hf_overrides_kw
,
hf_overrides_kw
=
hf_overrides_kw
,
hf_overrides_fn
=
hf_overrides_fn
)
hf_overrides_fn
=
hf_overrides_fn
)
self
.
hf_config
=
hf_config
self
.
hf_config
=
hf_config
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
attention_chunk_size
=
getattr
(
self
.
hf_text_config
,
self
.
attention_chunk_size
=
getattr
(
self
.
hf_text_config
,
"attention_chunk_size"
,
None
)
"attention_chunk_size"
,
None
)
...
@@ -2980,10 +2989,13 @@ class SpeculativeConfig:
...
@@ -2980,10 +2989,13 @@ class SpeculativeConfig:
"Chunked prefill and EAGLE are not compatible "
"Chunked prefill and EAGLE are not compatible "
"when using V0."
)
"when using V0."
)
from
vllm.transformers_utils.configs
import
(
SpeculatorsConfig
)
from
vllm.transformers_utils.configs.eagle
import
(
from
vllm.transformers_utils.configs.eagle
import
(
EAGLEConfig
)
EAGLEConfig
)
if
isinstance
(
self
.
draft_model_config
.
hf_config
,
if
isinstance
(
self
.
draft_model_config
.
hf_config
,
EAGLEConfig
):
(
EAGLEConfig
,
SpeculatorsConfig
)
):
pass
pass
else
:
else
:
eagle_config
=
EAGLEConfig
(
eagle_config
=
EAGLEConfig
(
...
...
vllm/engine/arg_utils.py
View file @
dfbc1f88
...
@@ -978,7 +978,27 @@ class EngineArgs:
...
@@ -978,7 +978,27 @@ class EngineArgs:
provided as a JSON string input via CLI arguments or directly as a
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
dictionary from the engine.
"""
"""
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.configs.speculators.base
import
(
SpeculatorsConfig
)
if
self
.
speculative_config
is
None
:
if
self
.
speculative_config
is
None
:
hf_config
=
get_config
(
self
.
hf_config_path
or
self
.
model
,
self
.
trust_remote_code
,
self
.
revision
,
self
.
code_revision
,
self
.
config_format
)
# if loading a SpeculatorsConfig, load the specualtive_config
# details from the config directly
# no user input required / expected
if
isinstance
(
hf_config
,
SpeculatorsConfig
):
# We create one since we dont create one
self
.
speculative_config
=
{}
self
.
speculative_config
[
"num_speculative_tokens"
]
=
hf_config
.
num_lookahead_tokens
self
.
speculative_config
[
"model"
]
=
self
.
model
self
.
speculative_config
[
"method"
]
=
hf_config
.
method
else
:
return
None
return
None
# Note(Shangming): These parameters are not obtained from the cli arg
# Note(Shangming): These parameters are not obtained from the cli arg
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
dfbc1f88
...
@@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self
.
hidden_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
hidden_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
getattr
(
config
,
"norm_before_residual"
,
False
):
self
.
_residual_norm
=
self
.
_norm_before_residual
else
:
self
.
_residual_norm
=
self
.
_norm_after_residual
def
_norm_before_residual
(
self
,
hidden_states
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
=
self
.
hidden_norm
(
hidden_states
)
residual
=
hidden_states
return
hidden_states
,
residual
def
_norm_after_residual
(
self
,
hidden_states
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
hidden_states
hidden_states
=
self
.
hidden_norm
(
hidden_states
)
return
hidden_states
,
residual
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual
=
hidden_states
embeds
=
self
.
input_layernorm
(
embeds
)
embeds
=
self
.
input_layernorm
(
embeds
)
hidden_states
=
self
.
hidden_norm
(
hidden_states
)
hidden_states
,
residual
=
self
.
_residual_norm
(
hidden_states
=
hidden_states
)
hidden_states
=
torch
.
cat
([
embeds
,
hidden_states
],
dim
=-
1
)
hidden_states
=
torch
.
cat
([
embeds
,
hidden_states
],
dim
=-
1
)
# Self Attention
# Self Attention
...
@@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
...
@@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
LlamaDecoderLayer
(
LlamaDecoderLayer
(
self
.
config
,
config
=
self
.
config
,
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
start_layer_id
}
"
),
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
start_layer_id
}
"
),
)
)
])
])
...
...
vllm/transformers_utils/config.py
View file @
dfbc1f88
...
@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
...
@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig
,
MLPSpeculatorConfig
,
MllamaConfig
,
MLPSpeculatorConfig
,
Nemotron_Nano_VL_Config
,
Nemotron_Nano_VL_Config
,
NemotronConfig
,
NVLM_D_Config
,
NemotronConfig
,
NVLM_D_Config
,
RWConfig
,
Step3TextConfig
,
RWConfig
,
SpeculatorsConfig
,
Step3VLConfig
,
UltravoxConfig
)
Step3TextConfig
,
Step3VLConfig
,
UltravoxConfig
)
# yapf: enable
# yapf: enable
from
vllm.transformers_utils.configs.mistral
import
adapt_config_dict
from
vllm.transformers_utils.configs.mistral
import
adapt_config_dict
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
...
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
...
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"mlp_speculator"
:
MLPSpeculatorConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"medusa"
:
MedusaConfig
,
"medusa"
:
MedusaConfig
,
"eagle"
:
EAGLEConfig
,
"eagle"
:
EAGLEConfig
,
"speculators"
:
SpeculatorsConfig
,
"nemotron"
:
NemotronConfig
,
"nemotron"
:
NemotronConfig
,
"NVLM_D"
:
NVLM_D_Config
,
"NVLM_D"
:
NVLM_D_Config
,
"ultravox"
:
UltravoxConfig
,
"ultravox"
:
UltravoxConfig
,
...
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
...
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return
config
return
config
def
maybe_override_with_speculators_target_model
(
model
:
str
,
tokenizer
:
str
,
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
)
->
tuple
[
str
,
str
]:
"""
If running a speculators config, override running model with target model
"""
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
model
,
revision
=
revision
,
trust_remote_code
=
trust_remote_code
,
token
=
_get_hf_token
(),
)
spec_config
=
config_dict
.
get
(
"speculators_config"
)
# Return the target model
if
spec_config
is
not
None
:
model
=
tokenizer
=
spec_config
[
"verifier"
][
"name_or_path"
]
return
model
,
tokenizer
def
get_config
(
def
get_config
(
model
:
Union
[
str
,
Path
],
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
...
@@ -345,9 +368,12 @@ def get_config(
...
@@ -345,9 +368,12 @@ def get_config(
token
=
_get_hf_token
(),
token
=
_get_hf_token
(),
**
kwargs
,
**
kwargs
,
)
)
# Use custom model class if it's in our registry
# Use custom model class if it's in our registry
model_type
=
config_dict
.
get
(
"model_type"
)
model_type
=
config_dict
.
get
(
"model_type"
)
if
model_type
is
None
:
model_type
=
"speculators"
if
config_dict
.
get
(
"speculators_config"
)
is
not
None
else
model_type
if
model_type
in
_CONFIG_REGISTRY
:
if
model_type
in
_CONFIG_REGISTRY
:
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config
=
config_class
.
from_pretrained
(
config
=
config_class
.
from_pretrained
(
...
...
vllm/transformers_utils/configs/__init__.py
View file @
dfbc1f88
...
@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
...
@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from
vllm.transformers_utils.configs.nemotron_h
import
NemotronHConfig
from
vllm.transformers_utils.configs.nemotron_h
import
NemotronHConfig
from
vllm.transformers_utils.configs.nemotron_vl
import
Nemotron_Nano_VL_Config
from
vllm.transformers_utils.configs.nemotron_vl
import
Nemotron_Nano_VL_Config
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.speculators.base
import
SpeculatorsConfig
from
vllm.transformers_utils.configs.step3_vl
import
(
Step3TextConfig
,
from
vllm.transformers_utils.configs.step3_vl
import
(
Step3TextConfig
,
Step3VisionEncoderConfig
,
Step3VisionEncoderConfig
,
Step3VLConfig
)
Step3VLConfig
)
...
@@ -44,6 +45,7 @@ __all__ = [
...
@@ -44,6 +45,7 @@ __all__ = [
"NemotronHConfig"
,
"NemotronHConfig"
,
"Nemotron_Nano_VL_Config"
,
"Nemotron_Nano_VL_Config"
,
"NVLM_D_Config"
,
"NVLM_D_Config"
,
"SpeculatorsConfig"
,
"UltravoxConfig"
,
"UltravoxConfig"
,
"Step3VLConfig"
,
"Step3VLConfig"
,
"Step3VisionEncoderConfig"
,
"Step3VisionEncoderConfig"
,
...
...
vllm/transformers_utils/configs/speculators/__init__.py
0 → 100644
View file @
dfbc1f88
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/transformers_utils/configs/speculators/algos.py
0 → 100644
View file @
dfbc1f88
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
SUPPORTED_SPECULATORS_TYPES
=
{}
def
register_speculator
(
name
):
def
decorator
(
fn
):
SUPPORTED_SPECULATORS_TYPES
[
name
]
=
fn
return
fn
return
decorator
@
register_speculator
(
"eagle3"
)
def
update_eagle3
(
config_dict
:
dict
,
vllm_config
:
dict
)
->
None
:
"""
Apply Eagle-3 specific configuration transformations.
Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
"""
vllm_config
[
"draft_vocab_size"
]
=
config_dict
.
get
(
"draft_vocab_size"
)
if
config_dict
.
get
(
"target_hidden_size"
)
is
not
None
:
vllm_config
[
"target_hidden_size"
]
=
config_dict
[
"target_hidden_size"
]
vllm_config
[
"norm_before_residual"
]
=
config_dict
.
get
(
"norm_before_residual"
,
True
)
vllm_config
[
"architectures"
]
=
[
"Eagle3LlamaForCausalLM"
]
vllm/transformers_utils/configs/speculators/base.py
0 → 100644
View file @
dfbc1f88
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
Any
,
Union
from
transformers
import
PretrainedConfig
from
vllm.transformers_utils.configs.speculators.algos
import
(
SUPPORTED_SPECULATORS_TYPES
)
__all__
=
[
"SpeculatorsConfig"
]
class
SpeculatorsConfig
(
PretrainedConfig
):
model_type
=
"speculators"
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
,
)
->
"SpeculatorsConfig"
:
"""Load speculators Eagle config and convert to vLLM format."""
config_dict
,
_
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
speculators_model_type
=
config_dict
.
get
(
"speculators_model_type"
)
if
speculators_model_type
not
in
SUPPORTED_SPECULATORS_TYPES
:
raise
ValueError
(
f
"Expected one of:
{
SUPPORTED_SPECULATORS_TYPES
}
. "
"Please ensure you're loading a speculators-format model."
)
# validate fields
# TODO: @dsikka - use speculators pydantic model to validate
cls
.
validate_speculators_config
(
config_dict
=
config_dict
)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config
=
cls
.
convert_speculators_to_vllm
(
config_dict
=
config_dict
)
# Apply anything specific to the supported algorithm
algo_updater
=
SUPPORTED_SPECULATORS_TYPES
[
speculators_model_type
]
algo_updater
(
config_dict
=
config_dict
,
vllm_config
=
vllm_config
)
return
cls
(
**
vllm_config
)
@
classmethod
def
validate_speculators_config
(
cls
,
config_dict
:
dict
[
str
,
Any
])
->
None
:
try
:
spec_config
=
config_dict
[
"speculators_config"
]
methods
=
spec_config
[
"proposal_methods"
]
first_method
=
methods
[
0
]
_
=
first_method
[
"speculative_tokens"
]
_
=
spec_config
[
"verifier"
][
"name_or_path"
]
_
=
config_dict
[
"speculators_model_type"
]
except
(
KeyError
,
IndexError
,
TypeError
)
as
e
:
raise
ValueError
(
"Invalid speculators config structure"
)
from
e
if
"transformer_layer_config"
not
in
config_dict
:
raise
ValueError
(
"Must provide transformer_layer_config"
)
if
not
isinstance
(
config_dict
[
"transformer_layer_config"
],
dict
):
raise
TypeError
(
"'transformer_layer_config' must be a dictionary if provided"
)
@
classmethod
def
convert_speculators_to_vllm
(
cls
,
config_dict
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
"""
Convert speculators config format to vLLM format.
This method handles the translation of field names and structure
between speculators and vLLM formats.
Returns:
Dictionary with vLLM-compatible configuration
"""
# Currently we only support one proposal method
spec_config
=
config_dict
[
"speculators_config"
]
first_method
=
spec_config
.
get
(
"proposal_methods"
)[
0
]
num_lookahead_tokens
=
first_method
.
get
(
"speculative_tokens"
)
if
num_lookahead_tokens
is
None
:
raise
ValueError
(
"Missing 'speculative_tokens' in proposal method. "
f
"Got:
{
first_method
}
"
)
# Build base vLLM config
vllm_config
=
{
"method"
:
config_dict
.
get
(
"speculators_model_type"
),
"num_lookahead_tokens"
:
num_lookahead_tokens
,
"target_model"
:
spec_config
.
get
(
"verifier"
)[
"name_or_path"
]
}
vllm_config
.
update
(
config_dict
[
"transformer_layer_config"
])
return
vllm_config
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment