"tests/vscode:/vscode.git/clone" did not exist on "e5db3e2774fd16394f8a96a608263ff2416385c8"
Unverified Commit 962d7038 authored by Divakar Verma's avatar Divakar Verma Committed by GitHub
Browse files

[Bugfix][llama4_eagle] Fix missing 'lm_head' attribute (#29926)


Signed-off-by: default avatarDivakar Verma <divakar.verma@amd.com>
parent e23ca3a0
......@@ -402,7 +402,11 @@ def test_eagle_correctness(
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
......
......@@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM
from vllm.model_executor.models.utils import extract_layer_index
......@@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.config.vocab_size, scale=logit_scale
)
self.lm_head = ParallelLMHead(
self.config.draft_vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Set MoE hyperparameters
self.set_moe_parameters()
......@@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
loader = AutoWeightsLoader(
self,
# lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes=(["lm_head."]),
skip_prefixes=([]),
)
loader.load_weights(map(transform, weights))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment