Unverified commit 20d37434, authored by Luciano Martins, committed by GitHub
Browse files

[Bugfix] Gemma4: fix multimodal embedder norm order to match HF reference (#40411)


Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
parent 18563f20
@@ -849,22 +849,23 @@ class Gemma4MultimodalEmbedder(nn.Module):
             or multimodal_config.hidden_size
         )
 
-        self.embedding_projection = ReplicatedLinear(
+        self.embedding_pre_projection_norm = RMSNorm(
             embedding_dim,
-            self.text_hidden_size,
-            bias=False,
+            eps=self.eps,
+            has_weight=False,
         )
 
-        self.embedding_post_projection_norm = RMSNorm(
+        self.embedding_projection = ReplicatedLinear(
+            embedding_dim,
             self.text_hidden_size,
-            eps=self.eps,
-            has_weight=False,
+            bias=False,
         )
 
     def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
         """Project soft tokens from a multimodal tower into LM space."""
-        embs_proj, _ = self.embedding_projection(inputs_embeds)
-        return self.embedding_post_projection_norm(embs_proj)
+        embs_normed = self.embedding_pre_projection_norm(inputs_embeds)
+        embs_proj, _ = self.embedding_projection(embs_normed)
+        return embs_proj
 
 
 # ---------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment