Commit e7cfd7c5 (unverified), authored by Fynn Schmitt-Ulms, committed by GitHub
Browse files

Add Gemma4 Eagle3 support (#39450)


Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Co-authored-by: Rahul-Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent e816a881
......@@ -825,6 +825,7 @@ class SpeculativeConfig:
"kimi_k2",
"kimi_k25",
"minimax_m2",
"gemma4",
]
if (
self.method in ("eagle3", "extract_hidden_states", "dflash")
......
......@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
extract_layer_index,
......@@ -838,7 +844,7 @@ class Gemma4CrossDecoderLayers(nn.Module):
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module):
class Gemma4Model(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = _get_text_config(vllm_config.model_config.hf_config)
......@@ -1168,7 +1174,7 @@ class Gemma4Model(nn.Module):
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
......@@ -1204,6 +1210,7 @@ class Gemma4Model(nn.Module):
residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs")
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
......@@ -1222,6 +1229,9 @@ class Gemma4Model(nn.Module):
per_layer_input=layer_per_input,
**kwargs,
)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
......@@ -1236,6 +1246,9 @@ class Gemma4Model(nn.Module):
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
......@@ -1381,7 +1394,9 @@ class Gemma4Model(nn.Module):
return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
class Gemma4ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts, SupportsEagle3
):
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding
# attention and full attention without k_eq_v). k_eq_v layers use
# separate q_proj + k_proj without packing.
......@@ -1463,7 +1478,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
......
......@@ -64,7 +64,12 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
......@@ -845,7 +850,12 @@ class Gemma4MultimodalEmbedder(nn.Module):
info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsEagle3,
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
......@@ -565,9 +565,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks:
computed.pop()
# Re-align after eagle pop: the pop may break the alignment
# when block_size != alignment_tokens (hybrid models with
# different page sizes, e.g. Gemma4).
while (
block_size != alignment_tokens
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
return computed_blocks
......
......@@ -1329,6 +1329,7 @@ class SpecDecodeBaseProposer:
"Qwen3_5MoeForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Gemma4ForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment