Unverified commit 1360801a authored by Pablo Montalvo, committed by GitHub

Add PaliGemma (#30814)



* add new model like

* add state dict slicing + new model config

* update palma config and weights, passes vision activations

* fix

* update

* reorder loading/unpacking

* clean up

* add debug statements

* change device

* fix

* debugging

* fix noncausal mask

* fixup sdpa + causal mask

* fix activation function

* remove debug before changing modeling file

* add variants

* debug attention mask in generate

* revert to non-debug sdpa

* revert gemma modifications

* add custom language modeling

* use Processor

* add language modeling file to init

* try thin wrapper around generate

* Update

* update mask

* breakpoints galore

* remove conflict

* switch to left-padding

* add incomplete model doc

* add paligemma global files

* batch rename paligemma

* make generation match outputs and captioning

* style

* style

* remove copied from + doc

* remove more copied from

* remove copy from projector

* minor fix

* update config and style

* add readme - dummy

* CORRECT image captioning

* moving to args

* add siglip proper + fix merging image + text features

* take update_causal_mask from upstream

* remove breakpoint

* leverage AutoModel

* fix input_ids slicing

* make siglip head conditional

* remove encoder_decoder value

* remove unneeded modeling file

* add commented 4d attention mask

* FIXED generation with 4D mask

* Update src/transformers/models/siglip/modeling_siglip.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix left padding detection

* shuffle order of verifications

* fix missing labels for training

* fix

* vectorize merging of features, improve slicing

* improve testing before conversion

* handle merging in processor

* image token index depends on checkpoint

* add variants, save processor too

* save processors, base tokenizer off spm file

* expand model embeddings due to additional image token

* pass image processing args

* add convert rgb to siglip processor

* add \n token separately

* fix tokenizer and prompts

* fix docstrings

* change to camel

* fix casing

* debug pos_ids and sdpa

* pass and use cache_position

* add flag for newline tokenization

* Update src/transformers/models/paligemma/processing_paligemma.py
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>

* simplify conversion script

* add copied from

* add precision to conversion script

* Update src/transformers/models/paligemma/modeling_paligemma.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* clean up

* Shift attention mask from `1:`

After discussion with @molbap

* add docs, fix quality

* quality, tied weights inheritance, and logits/label alignment

* fix more tests

* pass attn_implementation to language model correctly

* add SiglipVisionTransformer to no split modules

* skip paligemma test for sdpa dispatch to flash

* skip incompatible tests

* quality

* [broken archive maps]

* Apply suggestions

- remove archive lists
- style
- take shape of inputs_embeds for batch
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/utils/dummy_pt_objects.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* simplify conversion script

* add suggestions

* add suggestions

* add copied from

* fix

* move labels out

* revert

* fix

* remove placeholder labels if None

* use cache_position

* fix quality + docstrings

* fix quality

* fix paligemma 4d gemma mask incompatibility

* fix config docstring

* fix query and attn_mask dtype

---------
Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent c96aca3a
@@ -784,6 +784,8 @@
title: OWL-ViT
- local: model_doc/owlv2
title: OWLv2
- local: model_doc/paligemma
title: PaliGemma
- local: model_doc/perceiver
title: Perceiver
- local: model_doc/pix2struct
...
@@ -230,6 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| [OPT](model_doc/opt) | ✅ | ✅ | ✅ |
| [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ |
| [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ |
| [PaliGemma](model_doc/paligemma) | ✅ | ❌ | ❌ |
| [PatchTSMixer](model_doc/patchtsmixer) | ✅ | ❌ | ❌ |
| [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ |
| [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ |
...
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# PaliGemma
## Overview
The PaliGemma model was proposed by Google. It is a 3B vision-language model composed of a SigLIP-So400m vision encoder and a Gemma-2B decoder linked by a multimodal linear projection. It is not a chat model for images: the model cuts an image into a fixed number of ViT tokens and prepends them to an optional text prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. It comes in three resolutions, 224x224, 448x448 and 896x896, with 3 base models, 55 fine-tuned versions for different tasks, and 2 mix models.
This model was contributed by [Molbap](https://huggingface.co/Molbap).
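Below is a minimal usage sketch for captioning / VQA-style prompting. It mirrors the docstring example added in this PR; the checkpoint id and image URL are taken from that example and may change before the official release.
```python
from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/PaliGemma-test-224px-hf"  # test checkpoint from the docstring example in this PR
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)

# The processor prepends the image tokens to the text prompt and handles image resizing/normalization.
prompt = "answer en Where is the cow standing?"
url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_length=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```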
## PaliGemmaConfig
[[autodoc]] PaliGemmaConfig
## PaliGemmaProcessor
[[autodoc]] PaliGemmaProcessor
## PaliGemmaForConditionalGeneration
[[autodoc]] PaliGemmaForConditionalGeneration
- forward
@@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
...
@@ -582,6 +582,7 @@ _import_structure = {
"OwlViTTextConfig",
"OwlViTVisionConfig",
],
"models.paligemma": ["PaliGemmaConfig"],
"models.patchtsmixer": ["PatchTSMixerConfig"], "models.patchtsmixer": ["PatchTSMixerConfig"],
"models.patchtst": ["PatchTSTConfig"], "models.patchtst": ["PatchTSTConfig"],
"models.pegasus": [ "models.pegasus": [
...@@ -2651,6 +2652,13 @@ else: ...@@ -2651,6 +2652,13 @@ else:
"OwlViTVisionModel", "OwlViTVisionModel",
] ]
) )
_import_structure["models.paligemma"].extend(
[
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
"PaliGemmaProcessor",
]
)
_import_structure["models.patchtsmixer"].extend( _import_structure["models.patchtsmixer"].extend(
[ [
"PatchTSMixerForPrediction", "PatchTSMixerForPrediction",
...@@ -5126,6 +5134,9 @@ if TYPE_CHECKING: ...@@ -5126,6 +5134,9 @@ if TYPE_CHECKING:
OwlViTTextConfig, OwlViTTextConfig,
OwlViTVisionConfig, OwlViTVisionConfig,
) )
from .models.paligemma import (
PaliGemmaConfig,
)
from .models.patchtsmixer import (
PatchTSMixerConfig,
)
@@ -6956,6 +6967,11 @@ if TYPE_CHECKING:
OwlViTTextModel,
OwlViTVisionModel,
)
from .models.paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
PaliGemmaProcessor,
)
from .models.patchtsmixer import (
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
...
@@ -173,6 +173,7 @@ from . import (
opt,
owlv2,
owlvit,
paligemma,
patchtsmixer,
patchtst,
pegasus,
...
@@ -182,6 +182,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("opt", "OPTConfig"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paligemma", "PaliGemmaConfig"),
("patchtsmixer", "PatchTSMixerConfig"), ("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"), ("patchtst", "PatchTSTConfig"),
("pegasus", "PegasusConfig"), ("pegasus", "PegasusConfig"),
...@@ -464,6 +465,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -464,6 +465,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("opt", "OPT"), ("opt", "OPT"),
("owlv2", "OWLv2"), ("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"), ("owlvit", "OWL-ViT"),
("paligemma", "PaliGemma"),
("patchtsmixer", "PatchTSMixer"), ("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"), ("patchtst", "PatchTST"),
("pegasus", "Pegasus"), ("pegasus", "Pegasus"),
......
@@ -93,6 +93,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("oneformer", "OneFormerImageProcessor"),
("owlv2", "Owlv2ImageProcessor"),
("owlvit", "OwlViTImageProcessor"),
("paligemma", "CLIPImageProcessor"),
("perceiver", "PerceiverImageProcessor"), ("perceiver", "PerceiverImageProcessor"),
("pix2struct", "Pix2StructImageProcessor"), ("pix2struct", "Pix2StructImageProcessor"),
("poolformer", "PoolFormerImageProcessor"), ("poolformer", "PoolFormerImageProcessor"),
......
@@ -313,6 +313,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("nezha", "NezhaForPreTraining"),
("nllb-moe", "NllbMoeForConditionalGeneration"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("retribert", "RetriBertModel"), ("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"), ("roberta", "RobertaForMaskedLM"),
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
...@@ -697,6 +698,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( ...@@ -697,6 +698,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
("kosmos-2", "Kosmos2ForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
......
@@ -74,6 +74,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("oneformer", "OneFormerProcessor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("pix2struct", "Pix2StructProcessor"), ("pix2struct", "Pix2StructProcessor"),
("pop2piano", "Pop2PianoProcessor"), ("pop2piano", "Pop2PianoProcessor"),
("sam", "SamProcessor"), ("sam", "SamProcessor"),
......
@@ -331,6 +331,7 @@ else:
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"pegasus",
(
...
@@ -794,7 +794,6 @@ GEMMA_INPUTS_DOCSTRING = r"""
"The bare Gemma Model outputting raw hidden-states without any specific head on top.",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
class GemmaModel(GemmaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
@@ -988,8 +987,6 @@ class GemmaModel(GemmaPreTrainedModel):
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_paligemma": ["PaliGemmaConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_paligemma"] = [
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
]
_import_structure["processing_paligemma"] = ["PaliGemmaProcessor"]
if TYPE_CHECKING:
from .configuration_paligemma import PaliGemmaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
)
from .processing_paligemma import PaliGemmaProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaliGemmamodel configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class PaliGemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate a
PaliGemma model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`SiglipVisionConfig` or `dict`, *optional*):
Custom vision config or dict.
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone, typically a `GemmaConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
Vocabulary size of the PaliGemma model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`~PaliGemmaForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 2048):
Dimension of the multimodal projection space.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden layer of the Language model.
Example:
```python
>>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a PaliGemma config
>>> text_config = GemmaConfig()
>>> # Initializing a PaliGemma paligemma-3b-224 style configuration
>>> configuration = PaliGemmaConfig(vision_config, text_config)
>>> # Initializing a model from the paligemma-3b-224 style configuration
>>> model = PaliGemmaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "paligemma"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
self.is_encoder_decoder = False
if isinstance(self.vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
)
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
intermediate_size=4096,
hidden_size=1152,
patch_size=14,
image_size=224,
num_hidden_layers=27,
num_attention_heads=16,
vocab_size=257152,
vision_use_head=False,
)
self.vocab_size = self.vocab_size
self.text_config = text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.vocab_size = self.text_config.vocab_size
elif text_config is None:
self.text_config = CONFIG_MAPPING["gemma"](
hidden_size=2048,
num_hidden_layers=18,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1,
is_encoder_decoder=False,
)
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert PaliGemma checkpoints from the original repository.
"""
import argparse
import collections
import torch
from numpy import load
from transformers import (
AutoTokenizer,
GemmaTokenizer,
GemmaTokenizerFast,
PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
PaliGemmaProcessor,
SiglipImageProcessor,
)
from transformers.tokenization_utils_base import AddedToken
from transformers.utils import logging
device = "cuda" # "cpu"
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# TODO add sequence length variations here
PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"]
def get_paligemma_config(variant: str, precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
if variant in PALIGEMMA_VARIANTS:
image_size = image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152 if variant != "2b-test" else 256000
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
else:
raise ValueError(f"Identifier {variant} not supported. Available: {PALIGEMMA_VARIANTS}")
return final_config
def slice_state_dict(state_dict, config):
# fmt: off
# patch embeddings
state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose(
3, 2, 0, 1
)
state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias")
# positional embeddings
state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale")
encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias")
encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale")
encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias")
encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel")
encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias")
encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel")
encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias")
encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel")
encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias")
encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel")
encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias")
encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel")
encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias")
encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel")
encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose()
state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias")
# multimodal projector
state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose()
state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias")
# text decoder (gemma)
embedding_vector = state_dict.pop("llm/embedder/input_embedding")
state_dict["language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w")
llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w")
llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w")
llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum")
llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale")
llm_post_attention_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale")
state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
return state_dict
def flatten_nested_dict(params, parent_key="", sep="/"):
items = []
for k, v in params.items():
k = k.removeprefix("params/")
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
@torch.no_grad()
def convert_paligemma_checkpoint(
checkpoint_path,
tokenizer_model_file,
pytorch_dump_folder_path,
variant: str,
precision: str,
do_convert_weights=False,
):
"""
Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed.
"""
config = get_paligemma_config(variant, precision=precision)
if do_convert_weights:
if variant == "2b-test":
# for the test model, the vocabulary was smaller
tokenizer_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
else:
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
tokenizer = tokenizer_class(tokenizer_model_file)
image_token = AddedToken("<image>", normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
# tokenizer.padding_side = 'right' # uncomment for testing purposes only.
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size}
image_processor.image_seq_length = config.vision_config.num_image_tokens
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
data = load(checkpoint_path)
state_dict = flatten_nested_dict(data)
del data
state_dict_transformers = slice_state_dict(state_dict, config)
del state_dict
model = PaliGemmaForConditionalGeneration(config).to(device).eval()
model.load_state_dict(state_dict_transformers)
del state_dict_transformers
else:
processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path)
model = (
PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa")
.to(device)
.eval()
)
model.config.text_config._attn_implementation = "sdpa"
# model expansion to get random embeds of image tokens
pad_shape = 64 # for performance reasons
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we resize the model
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[257152:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))),
dim=0,
)
model.save_pretrained(pytorch_dump_folder_path, max_shard_size="2GB", safe_serialization=True)
processor.save_pretrained(pytorch_dump_folder_path)
#
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path",
required=True,
type=str,
help="Path to the .npz checkpoint",
)
parser.add_argument(
"--tokenizer_model_file",
required=True,
type=str,
help="Path to the sentencepiece tokenizer.model file",
)
parser.add_argument(
"--pytorch_dump_folder_path",
required=True,
type=str,
help="Path to the output directory where model and processor will be saved.",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
parser.add_argument(
"--variant",
default="2b-test",
choices=PALIGEMMA_VARIANTS,
type=str,
help="String identifier of the paligemma variant to convert.",
)
parser.add_argument(
"--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights."
)
args = parser.parse_args()
convert_paligemma_checkpoint(
checkpoint_path=args.checkpoint_path,
tokenizer_model_file=args.tokenizer_model_file,
pytorch_dump_folder_path=args.pytorch_dump_folder_path,
variant=args.variant,
precision=args.precision,
do_convert_weights=args.do_convert_weights,
)
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PaliGemmamodel."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_paligemma import PaliGemmaConfig
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from ..auto import AutoModel, AutoModelForCausalLM
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PaliGemmaConfig"
@dataclass
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
"""
Base class for PaliGemma causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
sequence_length, hidden_size)`.
Image hidden states of the model produced by the vision encoder.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
def forward(self, image_features):
hidden_states = self.linear(image_features)
return hidden_states
PALIGEMMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of PaliGemma isn't meant for training from scratch - only
# inference and fine-tuning
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
PALIGEMMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`PaliGemmaProcessor`] uses
[`SiglipImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""The PALIGEMMA model which consists of a vision backbone and a language model.""",
PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
self._attn_implementation = config._attn_implementation
language_model = AutoModelForCausalLM.from_config(
config=config.text_config, attn_implementation=self._attn_implementation
)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
_, _, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
scaled_image_features = image_features / (self.config.hidden_size**0.5)
final_embedding = torch.zeros(
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
image_mask = input_ids == self.config.image_token_index
pad_mask = input_ids == self.pad_token_id
# expand masks to match embedding dimension
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
# insert padding and text token embeddings
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
# insert image embeddings - the image mask is always less or equal to the sentence in length
final_embedding = final_embedding.masked_scatter(
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
)
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
final_attention_mask_4d = final_attention_mask_4d.float().expand(
-1, self.config.text_config.num_key_value_heads, -1, -1
)
# position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
# position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
if labels is not None:
final_labels = torch.full(
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
else:
final_labels = None
return final_embedding, final_attention_mask_4d, final_labels, position_ids
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# the attention mask is turned into a 4D mask later on; we keep track of the original 2D one
input_attention_mask = attention_mask
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
else:
# In case input_ids.shape[1] == 1, pixel_values is None and past_key_values is not None,
# we are in the case of generation with cache
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
# TODO @molbap this will only work for dynamic cache.
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_seqlen = cache_position[-1] + 1
extended_attention_mask = torch.ones(
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses PaliGemma + fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
attention_mask = attention_mask.to(inputs_embeds.dtype)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs.logits
logits = logits.float()
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if input_attention_mask is not None:
# use the original 2D attention mask (saved before it was modified above) to drop padded positions from the shifted logits and labels
shift_attention_mask = input_attention_mask[..., 1:]
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return PaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
attention_mask=None,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens; we can
# discard the first past_length tokens.
# Recall that here past_length is num_image_tokens + the length of the previous input_ids.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"cache_position": cache_position,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
)
return model_inputs
def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs)
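For reference, a minimal sketch of how the forward pass above is exercised with labels. The checkpoint, prompt and image URL are taken from the docstring example; reusing the prompt's `input_ids` as `labels` is an assumption made only to illustrate the shifted, mask-aware cross-entropy computed above, not the official training recipe.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")

url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="answer en Where is the cow standing?", images=image, return_tensors="pt")

# Reusing the prompt tokens as labels (illustration only): forward() merges the image
# features into the text embeddings, shifts logits/labels by one position, masks padded
# positions with the original 2D attention mask and computes a cross-entropy loss.
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss, outputs.logits.shape)
```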
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for PaliGemma.
"""
import logging
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import AddedToken, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
logger = logging.getLogger(__name__)
IMAGE_TOKEN = "<image>"
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
def _is_str_or_image(elem):
return isinstance(elem, (str)) or is_image_or_image_url(elem)
def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
"""
Builds a string from the input prompt and image tokens.
For example, for the call:
build_string_from_input(
prompt="Prefix str"
bos_token="<s>",
image_seq_len=3,
image_token="<im>",
)
The output will be:
"<im><im><im><s>Initial str"
Args:
prompt (`List[Union[str, ImageInput]]`): The input prompt.
bos_token (`str`): The beginning of sentence token.
image_seq_len (`int`): The length of the image sequence.
image_token (`str`): The image token.
"""
return f"{image_token * image_seq_len}{bos_token}{prompt}"
class PaliGemmaProcessor(ProcessorMixin):
r"""
Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
[`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`GemmaTokenizerFast`]. See the
[`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information.
Args:
image_processor ([`SiglipImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`GemmaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "SiglipImageProcessor"
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
if not hasattr(image_processor, "image_seq_length"):
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
self.image_seq_length = image_processor.image_seq_length
image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
super().__init__(image_processor, tokenizer)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
tokenize_newline_separately: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
do_resize: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821
resample: "PILImageResampling" = None, # noqa: F821
do_convert_rgb: bool = None,
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_rescale: bool = None,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
tokenize_newline_separately (`bool`, defaults to `True`):
Adds a separately tokenized '\n' at the end of the prompt.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is None:
raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
if text is None:
logger.warning_once(
"You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
)
if isinstance(text, List) and isinstance(images, List):
if len(images) < len(text):
raise ValueError(
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
)
if _is_str_or_image(text):
text = [text]
elif isinstance(text, list) and _is_str_or_image(text[0]):
pass
input_strings = [
build_string_from_input(
prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=IMAGE_TOKEN,
)
for prompt in text
]
pixel_values = self.image_processor(
images,
do_resize=do_resize,
do_normalize=do_normalize,
return_tensors=return_tensors,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
data_format=data_format,
resample=resample,
do_convert_rgb=do_convert_rgb,
)["pixel_values"]
if max_length is not None:
max_length += self.image_seq_length # max_length has to account for the image tokens
if tokenize_newline_separately:
inputs = self.tokenizer(
input_strings,
add_special_tokens=False,
return_tensors=None,
padding="do_not_pad",
max_length=max_length,
truncation=truncation,
)
newline_token = self.tokenizer.convert_tokens_to_ids("\n")
concatenated_ids = [ids + [newline_token] for ids in inputs["input_ids"]]
concatenated_attention_masks = [mask + [1] for mask in inputs["attention_mask"]]
text_inputs = self.tokenizer.pad(
{"input_ids": concatenated_ids, "attention_mask": concatenated_attention_masks},
max_length=max_length,
padding=padding,
return_tensors=return_tensors,
)
else:
text_inputs = self.tokenizer(
input_strings,
add_special_tokens=False,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
truncation=truncation,
)
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
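A short sketch of what the processor returns, assuming the same test checkpoint as in the example above; the exact `image_seq_length` and sequence length depend on the checkpoint and are only indicated in the comments.

```python
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")

url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)

# build_string_from_input prepends image_seq_length "<image>" tokens and the BOS token
# to the prompt; with tokenize_newline_separately=True (the default) a separately
# tokenized "\n" is appended after tokenization.
inputs = processor(text="answer en Where is the cow standing?", images=image, return_tensors="pt")

print(processor.image_seq_length)    # number of <image> placeholder tokens per image
print(inputs["input_ids"].shape)     # (1, image_seq_length + prompt_length + 1)
print(inputs["pixel_values"].shape)  # (1, 3, height, width), ready for the SigLIP tower
```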
src/transformers/models/siglip/image_processing_siglip.py

@@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Union
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
+    convert_to_rgb,
     resize,
     to_channel_dimension_format,
 )

@@ -73,6 +74,8 @@ class SiglipImageProcessor(BaseImageProcessor):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
            Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

@@ -87,6 +90,7 @@ class SiglipImageProcessor(BaseImageProcessor):
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

@@ -102,6 +106,7 @@ class SiglipImageProcessor(BaseImageProcessor):
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
+        self.do_convert_rgb = do_convert_rgb
        self._valid_processor_keys = [
            "images",
            "do_resize",

@@ -115,6 +120,7 @@ class SiglipImageProcessor(BaseImageProcessor):
            "return_tensors",
            "data_format",
            "input_data_format",
+            "do_convert_rgb",
        ]

    def preprocess(

@@ -131,6 +137,7 @@ class SiglipImageProcessor(BaseImageProcessor):
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        do_convert_rgb: bool = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -176,6 +183,8 @@ class SiglipImageProcessor(BaseImageProcessor):
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size

@@ -186,6 +195,7 @@ class SiglipImageProcessor(BaseImageProcessor):
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

@@ -209,6 +219,9 @@ class SiglipImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
...
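A small sketch of the new `do_convert_rgb` path, using an RGBA image built in memory; the flag and its wiring follow the hunks above, while the image itself is arbitrary.

```python
import numpy as np
from PIL import Image
from transformers import SiglipImageProcessor

# Random RGBA image: the alpha channel is what do_convert_rgb is meant to strip.
rgba = Image.fromarray(np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8), mode="RGBA")

image_processor = SiglipImageProcessor(do_convert_rgb=True)
pixel_values = image_processor(rgba, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # (1, 3, 224, 224): convert_to_rgb ran before rescaling/normalization
```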
src/transformers/models/siglip/modeling_siglip.py

@@ -881,6 +881,8 @@ class SiglipVisionTransformer(nn.Module):
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.head = SiglipMultiheadAttentionPoolingHead(config)
+        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+        if self.use_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)

@@ -915,14 +917,13 @@
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

-        pooled_output = self.head(last_hidden_state)
+        pooler_output = self.head(last_hidden_state) if self.use_head else None
        if not return_dict:
-            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+            return (last_hidden_state, pooler_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
-            pooler_output=pooled_output,
+            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

@@ -959,6 +960,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
class SiglipVisionModel(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"
+    _no_split_modules = ["SiglipVisionTransformer"]

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
...
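A sketch of the conditional pooling head introduced above; setting `vision_use_head = False` on the vision config is how PaliGemma skips SigLIP's attention-pooling head, and the tiny config values here are arbitrary.

```python
import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

config = SiglipVisionConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, image_size=32, patch_size=16,
)
config.vision_use_head = False  # the head is only built when this attribute is absent or True

model = SiglipVisionModel(config)
with torch.no_grad():
    outputs = model(pixel_values=torch.randn(1, 3, 32, 32))

print(outputs.last_hidden_state.shape)  # (1, 4, 32): (32 // 16) ** 2 patches, hidden_size 32
print(outputs.pooler_output)            # None, since the pooling head was skipped
```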
src/transformers/utils/dummy_pt_objects.py

@@ -6135,6 +6135,27 @@ class OwlViTVisionModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


+class PaliGemmaForConditionalGeneration(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PaliGemmaPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PaliGemmaProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
class PatchTSMixerForPrediction(metaclass=DummyObject):
    _backends = ["torch"]
...
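These dummies keep `from transformers import PaliGemmaForConditionalGeneration` importable in an environment without PyTorch; it is the instantiation that fails. A sketch of that behaviour, only meaningful when torch is not installed:

```python
from transformers import PaliGemmaForConditionalGeneration

try:
    PaliGemmaForConditionalGeneration()
except ImportError as err:
    # requires_backends raises an ImportError explaining that PyTorch must be installed.
    print(err)
```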