"vscode:/vscode.git/clone" did not exist on "391612e78b8fb6dc1322fc0db5ab3a5b3732463e"
Unverified Commit 05f6846e authored by Rahul Tuli's avatar Rahul Tuli Committed by GitHub
Browse files

Support llama3 eagle3 head with llama4 verifier (#25961)


Signed-off-by: default avatarrahul-tuli <rtuli@redhat.com>
Signed-off-by: default avatarRahul Tuli <rtuli@redhat.com>
parent 20db99cc
......@@ -604,6 +604,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Override to return default layers for Llama
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
......@@ -241,7 +242,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
is_multimodal: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
......
......@@ -64,7 +64,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
......@@ -717,7 +722,9 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
......@@ -767,6 +774,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "set_aux_hidden_state_layers")
self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers()
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[Llama4ImagePatchInputs]:
......
......@@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
model to use as auxiliary inputs for the Eagle3 drafter. These layers
provide intermediate hidden states that help the drafter make better
predictions. This is the standard field used in Eagle3 checkpoints.
"""
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
......@@ -28,3 +32,7 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids"
]
......@@ -2943,15 +2943,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
if supports_eagle3(self.model):
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers()
)
else:
if not supports_eagle3(self.model):
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested"
)
# Try to get auxiliary layers from speculative config,
# otherwise use model's default layers
aux_layers = self._get_eagle3_aux_layers_from_config()
if aux_layers:
logger.info(
"Using auxiliary layers from speculative config: %s",
aux_layers,
)
else:
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
......@@ -3006,6 +3015,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
)
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
"""Extract Eagle3 auxiliary layer indices from speculative config.
These indices specify which hidden states from the base model should
be used as auxiliary inputs for the Eagle3 drafter model during
speculative decoding.
Returns:
Tuple of layer indices if found in draft model config,
None otherwise.
"""
if not (self.speculative_config and self.speculative_config.draft_model_config):
return None
hf_config = self.speculative_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
return None
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, (
"Cannot reload weights before model is loaded."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment