Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
962d7038
Unverified
Commit
962d7038
authored
Dec 05, 2025
by
Divakar Verma
Committed by
GitHub
Dec 05, 2025
Browse files
[Bugfix][llama4_eagle] Fix missing 'lm_head' attribute (#29926)
Signed-off-by:
Divakar Verma
<
divakar.verma@amd.com
>
parent
e23ca3a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
3 deletions
+16
-3
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+5
-1
vllm/model_executor/models/llama4_eagle.py
vllm/model_executor/models/llama4_eagle.py
+11
-2
No files found.
tests/v1/e2e/test_spec_decode.py
View file @
962d7038
...
@@ -402,7 +402,11 @@ def test_eagle_correctness(
...
@@ -402,7 +402,11 @@ def test_eagle_correctness(
# Scout requires default backend selection
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
# pass if not ROCm
if
current_platform
.
is_rocm
():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest
.
skip
(
"Flex Attn for spec_decode not supported on ROCm currently"
)
else
:
else
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
...
...
vllm/model_executor/models/llama4_eagle.py
View file @
962d7038
...
@@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.torchao
import
TorchAOConfig
from
vllm.model_executor.layers.quantization.torchao
import
TorchAOConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama4
import
Llama4DecoderLayer
,
Llama4ForCausalLM
from
vllm.model_executor.models.llama4
import
Llama4DecoderLayer
,
Llama4ForCausalLM
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
...
@@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
...
@@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self
.
config
.
vocab_size
,
scale
=
logit_scale
self
.
config
.
vocab_size
,
scale
=
logit_scale
)
)
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
draft_vocab_size
,
self
.
config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
# Set MoE hyperparameters
# Set MoE hyperparameters
self
.
set_moe_parameters
()
self
.
set_moe_parameters
()
...
@@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
...
@@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
loader
=
AutoWeightsLoader
(
loader
=
AutoWeightsLoader
(
self
,
self
,
# lm_head is tied with target model (Llama4ForCausalLM)
# lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes
=
([
"lm_head."
]),
skip_prefixes
=
([]),
)
)
loader
.
load_weights
(
map
(
transform
,
weights
))
loader
.
load_weights
(
map
(
transform
,
weights
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment