Unverified commit 30be1884, authored by drbh, committed by GitHub

Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)

* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before the post layernorm is applied (see the original transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813).

This also makes Siglip consistent with the existing Clip implementation:

https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613



* fix: adjust pali gemma for post layer norm and small refactors

---------
Co-authored-by: Travis Addair <tgaddair@gmail.com>
parent f3c5d7d9
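
The change is easiest to see as a toy example before reading the diff. Below is a minimal, self-contained sketch (not TGI code; the tensor shapes and module names are invented for illustration) of the contract this commit enforces: the shared vision tower hands back pre-layernorm encoder states, and each consumer decides whether to normalize them.

# Illustrative toy, not TGI code: tensor shapes and names are invented.
import torch
import torch.nn as nn

hidden_size = 8
encoder_output = torch.randn(1, 4, hidden_size)  # stand-in for Siglip encoder states

post_layernorm = nn.LayerNorm(hidden_size)

# LLaVA Next path: consumes the raw (pre-layernorm) encoder output directly.
llava_features = encoder_output

# PaliGemma path: applies the post layernorm itself before projecting.
pali_features = post_layernorm(encoder_output)

# The two consumers need different tensors, which is why the layernorm must
# live with the caller rather than inside the shared vision transformer.
assert not torch.allclose(llava_features, pali_features)
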
@@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             config=config.vision_config,
             weights=weights,
         )
+        self.post_vision_tower_layernorm = nn.LayerNorm.load(
+            prefix="vision_tower.vision_model.post_layernorm",
+            weights=weights,
+            eps=config.vision_config.layer_norm_eps,
+        )
 
         self.multi_modal_projector = TensorParallelColumnLinear.load(
             config,
@@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
-            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+            last_hidden_state = self.post_vision_tower_layernorm(
+                image_outputs.last_hidden_state
+            )
+            image_features = self.multi_modal_projector(last_hidden_state)
             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index
@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
             hidden_states, _ = encoder_layer(
@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
 
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
     ):
-        r"""
-        Returns:
-        """
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
@@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
             inputs_embeds=hidden_states,
         )
         last_hidden_state = encoder_outputs
-        post_last_hidden_state = self.post_layernorm(last_hidden_state)
 
         return BaseModelOutputWithPooling(
-            last_hidden_state=post_last_hidden_state,
+            last_hidden_state=last_hidden_state,
             # pooler_output=pooled_output,
             # hidden_states=encoder_outputs,
         )
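
For reference, here is a self-contained sketch of the structure after this commit. TinyVisionTransformer and TinyPaliGemma are invented stand-ins for SiglipVisionTransformer and PaliGemmaForConditionalGeneration, reduced to the layernorm-ownership question; nothing here is the actual TGI implementation.

# Toy stand-ins, not the real classes from this diff.
import torch
import torch.nn as nn

class TinyVisionTransformer(nn.Module):
    # Mirrors SiglipVisionTransformer after this commit: the forward pass
    # returns encoder states without applying any post layernorm.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)  # toy "encoder"

    def forward(self, pixel_embeds: torch.Tensor) -> torch.Tensor:
        return self.encoder(pixel_embeds)  # pre-layernorm hidden state

class TinyPaliGemma(nn.Module):
    # Mirrors PaliGemmaForConditionalGeneration: it owns the layernorm and
    # applies it before projecting image features.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.vision_tower = TinyVisionTransformer(hidden_size)
        self.post_vision_tower_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_embeds: torch.Tensor) -> torch.Tensor:
        hidden = self.vision_tower(pixel_embeds)
        return self.post_vision_tower_layernorm(hidden)

print(TinyPaliGemma(hidden_size=8)(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 8])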