Commit ee642f87 authored by Fynn Schmitt-Ulms's avatar Fynn Schmitt-Ulms Committed by khluu
Browse files

Add Gemma4 Eagle3 support (#39450)


Signed-off-by: default avatarRahul-Tuli <rtuli@redhat.com>
Signed-off-by: default avatarFynn Schmitt-Ulms <fschmitt@redhat.com>
Co-authored-by: default avatarRahul-Tuli <rtuli@redhat.com>
Co-authored-by: default avatarClaude <noreply@anthropic.com>
(cherry picked from commit e7cfd7c5)
parent 6db56c09
...@@ -805,6 +805,8 @@ class SpeculativeConfig: ...@@ -805,6 +805,8 @@ class SpeculativeConfig:
"deepseek_v3", "deepseek_v3",
"kimi_k2", "kimi_k2",
"kimi_k25", "kimi_k25",
"minimax_m2",
"gemma4",
] ]
if ( if (
self.method in ("eagle3", "extract_hidden_states") self.method in ("eagle3", "extract_hidden_states")
......
...@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
extract_layer_index, extract_layer_index,
...@@ -838,7 +844,7 @@ class Gemma4CrossDecoderLayers(nn.Module): ...@@ -838,7 +844,7 @@ class Gemma4CrossDecoderLayers(nn.Module):
@support_torch_compile( @support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
) )
class Gemma4Model(nn.Module): class Gemma4Model(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = _get_text_config(vllm_config.model_config.hf_config) config = _get_text_config(vllm_config.model_config.hf_config)
...@@ -1168,7 +1174,7 @@ class Gemma4Model(nn.Module): ...@@ -1168,7 +1174,7 @@ class Gemma4Model(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
**kwargs, **kwargs,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if self.fast_prefill_enabled: if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward( hidden_states = self.fast_prefill_forward(
input_ids, input_ids,
...@@ -1204,6 +1210,7 @@ class Gemma4Model(nn.Module): ...@@ -1204,6 +1210,7 @@ class Gemma4Model(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs") per_layer_inputs = intermediate_tensors.get("per_layer_inputs")
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate( for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
...@@ -1222,6 +1229,9 @@ class Gemma4Model(nn.Module): ...@@ -1222,6 +1229,9 @@ class Gemma4Model(nn.Module):
per_layer_input=layer_per_input, per_layer_input=layer_per_input,
**kwargs, **kwargs,
) )
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
{ {
...@@ -1236,6 +1246,9 @@ class Gemma4Model(nn.Module): ...@@ -1236,6 +1246,9 @@ class Gemma4Model(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
else: else:
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
...@@ -1381,7 +1394,9 @@ class Gemma4Model(nn.Module): ...@@ -1381,7 +1394,9 @@ class Gemma4Model(nn.Module):
return loaded_params return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): class Gemma4ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts, SupportsEagle3
):
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding # Note: qkv_proj packing applies to non-k_eq_v layers (sliding
# attention and full attention without k_eq_v). k_eq_v layers use # attention and full attention without k_eq_v). k_eq_v layers use
# separate q_proj + k_proj without packing. # separate q_proj + k_proj without packing.
...@@ -1463,7 +1478,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -1463,7 +1478,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs, **kwargs,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
) )
......
...@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import ( ...@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
...@@ -848,7 +853,12 @@ class Gemma4MultimodalEmbedder(nn.Module): ...@@ -848,7 +853,12 @@ class Gemma4MultimodalEmbedder(nn.Module):
info=Gemma4ProcessingInfo, info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder, dummy_inputs=Gemma4DummyInputsBuilder,
) )
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsEagle3,
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -565,9 +565,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager): ...@@ -565,9 +565,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
if use_eagle and computed_blocks[0]: if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, ( for computed in computed_blocks:
"aligned_length is not compatible with eagle now" computed.pop()
) # Re-align after eagle pop: the pop may break the alignment
# when block_size != alignment_tokens (hybrid models with
# different page sizes, e.g. Gemma4).
while (
block_size != alignment_tokens
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
return computed_blocks return computed_blocks
......
...@@ -1282,6 +1282,7 @@ class SpecDecodeBaseProposer: ...@@ -1282,6 +1282,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration", "Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration",
"Gemma4ForConditionalGeneration",
"HunYuanVLForConditionalGeneration", "HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration", "GlmOcrForConditionalGeneration",
"Qwen3_5ForConditionalGeneration", "Qwen3_5ForConditionalGeneration",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment