Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
145ac733
Unverified
Commit
145ac733
authored
Sep 29, 2025
by
Rahul Tuli
Committed by
GitHub
Sep 29, 2025
Browse files
[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)
Signed-off-by:
Rahul Tuli
<
rtuli@redhat.com
>
parent
d0d138bc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
2 deletions
+22
-2
tests/speculative_decoding/speculators/test_eagle3.py
tests/speculative_decoding/speculators/test_eagle3.py
+3
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+6
-1
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+13
-1
No files found.
tests/speculative_decoding/speculators/test_eagle3.py
View file @
145ac733
...
@@ -14,6 +14,9 @@ from vllm.model_executor.models.interfaces import supports_eagle3
...
@@ -14,6 +14,9 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest
.
param
(
pytest
.
param
(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"
,
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"
,
id
=
"qwen3-eagle3-speculator"
),
id
=
"qwen3-eagle3-speculator"
),
pytest
.
param
(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16"
,
id
=
"qwen3-eagle3-speculator-w4a16-verifier"
),
])
])
def
test_eagle3_speculators_model
(
vllm_runner
,
example_prompts
,
model_path
,
def
test_eagle3_speculators_model
(
vllm_runner
,
example_prompts
,
model_path
,
monkeypatch
):
monkeypatch
):
...
...
vllm/model_executor/models/llama.py
View file @
145ac733
...
@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
config
=
config
or
vllm_config
.
model_config
.
hf_config
config
=
config
or
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
self
.
get_quant_config
(
vllm_config
)
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
...
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
return
hidden_states
,
residual
def
get_quant_config
(
self
,
vllm_config
:
VllmConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get quantization config for this layer. Override in subclasses."""
return
vllm_config
.
quant_config
@
support_torch_compile
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
145ac733
...
@@ -13,6 +13,8 @@ from vllm.logger import init_logger
...
@@ -13,6 +13,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
super
().
__init__
(
vllm_config
,
prefix
=
prefix
,
config
=
config
)
super
().
__init__
(
vllm_config
,
prefix
=
prefix
,
config
=
config
)
config
=
config
or
vllm_config
.
model_config
.
hf_config
config
=
config
or
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
self
.
get_quant_config
(
vllm_config
)
# override qkv
# override qkv
self
.
self_attn
.
qkv_proj
=
QKVParallelLinear
(
self
.
self_attn
.
qkv_proj
=
QKVParallelLinear
(
...
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
else
:
else
:
self
.
_residual_norm
=
self
.
_norm_after_residual
self
.
_residual_norm
=
self
.
_norm_after_residual
def
get_quant_config
(
self
,
vllm_config
:
VllmConfig
)
->
Optional
[
QuantizationConfig
]:
"""Use drafter's quantization config instead of verifier's."""
draft_model_config
=
vllm_config
.
speculative_config
.
draft_model_config
draft_load_config
=
vllm_config
.
load_config
return
VllmConfig
.
get_quantization_config
(
draft_model_config
,
draft_load_config
)
if
draft_model_config
else
None
def
_norm_before_residual
(
def
_norm_before_residual
(
self
,
self
,
hidden_states
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment