Unverified commit 30be1884 authored by drbh, committed by GitHub

Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)

* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when Siglip is used as the vision model. LLaVA Next expects the vision model's output to be the encoder outputs before the post layernorm is applied (see the original transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813). A sketch of the resulting call flow appears after these notes.

This also makes Siglip consistent with the existing Clip implementation:

https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613



* fix: adjust pali gemma for post layer norm and small refactors
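
For illustration, a minimal sketch of the call flow this change establishes (toy modules with assumed sizes, not the actual TGI classes): the vision transformer returns the raw encoder output, and a caller that needs normalized features, such as PaliGemma, applies the post layernorm itself.

```python
import torch
from torch import nn
from typing import Optional


class VisionTransformerSketch(nn.Module):
    """Toy stand-in for SiglipVisionTransformer after this change."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        # Stand-in for the full encoder stack.
        self.encoder = nn.Linear(hidden_size, hidden_size)

    def forward(self, pixel_values: Optional[torch.FloatTensor] = None):
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # Return the raw encoder output; no post layernorm is applied here.
        return self.encoder(pixel_values)


class CallerSketch(nn.Module):
    """Toy stand-in for the PaliGemma-side vision path."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.vision_tower = VisionTransformerSketch(hidden_size)
        # The caller now owns the post layernorm.
        self.post_vision_tower_layernorm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, pixel_values: torch.FloatTensor):
        last_hidden_state = self.vision_tower(pixel_values)
        # Normalize before projecting into the language model's space.
        return self.post_vision_tower_layernorm(last_hidden_state)
```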

---------
Co-authored-by: Travis Addair <tgaddair@gmail.com>
parent f3c5d7d9
@@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             config=config.vision_config,
             weights=weights,
         )
+        self.post_vision_tower_layernorm = nn.LayerNorm.load(
+            prefix="vision_tower.vision_model.post_layernorm",
+            weights=weights,
+            eps=config.vision_config.layer_norm_eps,
+        )
 
         self.multi_modal_projector = TensorParallelColumnLinear.load(
             config,
@@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
-            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+            last_hidden_state = self.post_vision_tower_layernorm(
+                image_outputs.last_hidden_state
+            )
+            image_features = self.multi_modal_projector(last_hidden_state)
 
             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index
...@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module): ...@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
inputs_embeds, inputs_embeds,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
): ):
hidden_states = inputs_embeds hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers): for idx, encoder_layer in enumerate(self.layers):
hidden_states, _ = encoder_layer( hidden_states, _ = encoder_layer(
...@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module): ...@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights prefix=f"{prefix}.encoder", config=config, weights=weights
) )
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward( def forward(
self, self,
pixel_values: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None,
): ):
r"""
Returns:
"""
if pixel_values is None: if pixel_values is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
...@@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module): ...@@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
) )
last_hidden_state = encoder_outputs last_hidden_state = encoder_outputs
post_last_hidden_state = self.post_layernorm(last_hidden_state)
return BaseModelOutputWithPooling( return BaseModelOutputWithPooling(
last_hidden_state=post_last_hidden_state, last_hidden_state=last_hidden_state,
# pooler_output=pooled_output, # pooler_output=pooled_output,
# hidden_states=encoder_outputs, # hidden_states=encoder_outputs,
) )
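
As a quick sanity check of the new split, using the toy classes from the sketch above (not the TGI modules): the vision tower output is pre-layernorm, so normalizing it on the caller side must change the activations.

```python
model = CallerSketch(hidden_size=16)
pixel_values = torch.randn(1, 4, 16)

raw = model.vision_tower(pixel_values)  # pre-layernorm encoder output
normed = model(pixel_values)            # caller-side layernorm applied
assert not torch.allclose(raw, normed)
```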