Unverified Commit 65986db6 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Make Gemma and Gemma 2 accept `inputs_embeds` like Gemma 3 (#36787)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9556af87
...@@ -11,6 +11,8 @@ from unittest.mock import Mock ...@@ -11,6 +11,8 @@ from unittest.mock import Mock
import pytest import pytest
import torch import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm import LLM from vllm import LLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -91,6 +93,15 @@ def test_models( ...@@ -91,6 +93,15 @@ def test_models(
if enable_prompt_embeds: if enable_prompt_embeds:
with torch.no_grad(): with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
if model == "hmellor/tiny-random-Gemma2ForCausalLM" and (
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
):
# For Gemma 1/2 models with Transformers 5.4.0+, the prompt embeddings
# are normalised in `get_prompt_embeddings`, like Gemma 3.
# For older versions, we need to manually normalise.
embed_scale = hf_model.config.hidden_size**0.5
normalizer = torch.tensor(embed_scale, dtype=prompt_embeds[0].dtype)
prompt_embeds = [p_e * normalizer for p_e in prompt_embeds]
with VllmRunner( with VllmRunner(
model, model,
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import pytest import pytest
import torch import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -151,6 +153,16 @@ def test_models( ...@@ -151,6 +153,16 @@ def test_models(
if prompt_embeds is not None: if prompt_embeds is not None:
embed = hf_model.model.get_input_embeddings()(token_ids) embed = hf_model.model.get_input_embeddings()(token_ids)
if "gemma" in model.lower() and (
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
):
# For Gemma 1/2 models with Transformers 5.4.0+, the prompt
# embeddings are normalised in `get_prompt_embeddings`,
# like Gemma 3. For older versions, we need to manually normalise.
embed_scale = hf_model.config.hidden_size**0.5
normalizer = torch.tensor(embed_scale, dtype=embed.dtype)
embed *= normalizer
# MiniCPM models apply scale_emb to embeddings internally. # MiniCPM models apply scale_emb to embeddings internally.
# vLLM expects pre-scaled embeddings when using inputs_embeds. # vLLM expects pre-scaled embeddings when using inputs_embeds.
if model in EMBED_SCALING_MODELS: if model in EMBED_SCALING_MODELS:
......
...@@ -293,7 +293,7 @@ class GemmaModel(nn.Module): ...@@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
) )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids) * self.normalizer
def forward( def forward(
self, self,
...@@ -307,7 +307,6 @@ class GemmaModel(nn.Module): ...@@ -307,7 +307,6 @@ class GemmaModel(nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.embed_input_ids(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.normalizer
residual = None residual = None
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
......
...@@ -284,7 +284,7 @@ class Gemma2Model(nn.Module): ...@@ -284,7 +284,7 @@ class Gemma2Model(nn.Module):
) )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids) * self.normalizer
def forward( def forward(
self, self,
...@@ -298,7 +298,6 @@ class Gemma2Model(nn.Module): ...@@ -298,7 +298,6 @@ class Gemma2Model(nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.embed_input_ids(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.normalizer
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment