"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "692c3c6b73b8d4cb312950f60a05ab8ad37eff04"
Unverified commit 1360801a, authored by Pablo Montalvo, committed by GitHub

Add PaliGemma (#30814)



* add new model like

* add state dict slicing + new model config

* update palma config and weights, passes vision activations

* fix

* update

* reorder loading/unpacking

* clean up

* add debug statements

* change device

* fix

* debugging

* fix noncausal mask

* fixup sdpa + causal mask

* fix activation function

* remove debug before changing modeling file

* add variants

* debug attention mask in generate

* revert to non-debug sdpa

* revert gemma modifications

* add custom language modeling

* use Processor

* add language modeling file to init

* try thin wrapper around generate

* Update

* update mask

* breakpoints galore

* remove conflict

* switch to left-padding

* add incomplete model doc

* add paligemma global files

* batch rename paligemma

* make generation match outputs and captioning

* style

* style

* remove copied from + doc

* remove more copied from

* remove copy from projector

* minor fix

* update config and style

* add readme - dummy

* CORRECT image captioning

* moving to args

* add siglip proper + fix merging image + text features

* take update_causal_mask from upstream

* remove breakpoint

* leverage AutoModel

* fix input_ids slicing

* make siglip head conditional

* remove encoder_decoder value

* remove unneeded modeling file

* add commented 4d attention mask

* FIXED generation with 4D mask

* Update src/transformers/models/siglip/modeling_siglip.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix left padding detection

* shuffle order of verifications

* fix missing labels for training

* fix

* vectorize merging of features, improve slicing

* improve testing before conversion

* handle merging in processor

* image token index depends on checkpoint

* add variants, save processor too

* save processors, base tokenizer off spm file

* expand model embeddings due to additional image token

* pass image processing args

* add convert rgb to siglip processor

* add \n token separately

* fix tokenizer and prompts

* fix docstrings

* change to camel

* fix casing

* debug pos_ids and sdpa

* pass and use cache_position

* add flag for newline tokenization

* Update src/transformers/models/paligemma/processing_paligemma.py
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>

* simplify conversion script

* add copied from

* add precision to conversion script

* Update src/transformers/models/paligemma/modeling_paligemma.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* clean up

* Shift attention mask from `1:`

After discussion with @molbap

* add docs, fix quality

* quality, tied weights inheritance, and logits/label alignment

* fix more tests

* pass attn_implementation to language model correctly

* add SiglipVisionTransformer to no split modules

* skip paligemma test for sdpa dispatch to flash

* skip incompatible tests

* quality

* [broken archive maps]

* Apply suggestions

- remove archive lists
- style
- take shape of inputs_embeds for batch
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/utils/dummy_pt_objects.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* simplify conversion script

* add suggestions

* add suggestions

* add copied from

* fix

* move labels out

* revert

* fix

* remove placeholder labels if None

* use cache_position

* fix quality + docstrings

* fix quality

* fix paligemma 4d gemma mask incompatibility

* fix config docstring

* fix query and attn_mask dtype

---------
Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent c96aca3a
@@ -784,6 +784,8 @@
         title: OWL-ViT
       - local: model_doc/owlv2
         title: OWLv2
+      - local: model_doc/paligemma
+        title: PaliGemma
       - local: model_doc/perceiver
         title: Perceiver
       - local: model_doc/pix2struct
...
@@ -230,6 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [OPT](model_doc/opt) | ✅ | ✅ | ✅ |
 | [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ |
 | [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ |
+| [PaliGemma](model_doc/paligemma) | ✅ | ❌ | ❌ |
 | [PatchTSMixer](model_doc/patchtsmixer) | ✅ | ❌ | ❌ |
 | [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ |
 | [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ |
...
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# PaliGemma
## Overview
The PaliGemma model was proposed by Google. It is a 3B vision-language model composed of a SigLIP-400m vision encoder and a Gemma-2B decoder linked by a multimodal linear projection. It is not a chat model with images: it cuts an image into a fixed number of ViT tokens and prepends them to an optional text prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. The model comes in three resolutions (224x224, 448x448 and 896x896) with three base models, 55 fine-tuned versions for different tasks, and two mix models.
This model was contributed by [Molbap](https://huggingface.co/Molbap).
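A minimal captioning sketch is shown below. The checkpoint id and the image URL are placeholders for illustration (any PaliGemma checkpoint and any RGB image work the same way).

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-mix-224"  # assumed checkpoint id, for illustration only
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)

# The processor prepends `image_seq_length` <image> tokens to the (optional) text prefix.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="caption en", images=image, return_tensors="pt")

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Slice off the prompt so only the newly generated caption is decoded.
prompt_len = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_len:], skip_special_tokens=True))
```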
## PaliGemmaConfig
[[autodoc]] PaliGemmaConfig
## PaliGemmaProcessor
[[autodoc]] PaliGemmaProcessor
## PaliGemmaForConditionalGeneration
[[autodoc]] PaliGemmaForConditionalGeneration
- forward
@@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
 * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
+* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
 * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
...
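Note: as with the other architectures in that list, SDPA can be requested explicitly when loading the model. A short sketch; the checkpoint id is illustrative and not part of this diff:

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

# "sdpa" routes attention through torch.nn.functional.scaled_dot_product_attention.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-mix-224",  # assumed checkpoint id
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```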
@@ -582,6 +582,7 @@ _import_structure = {
         "OwlViTTextConfig",
         "OwlViTVisionConfig",
     ],
+    "models.paligemma": ["PaliGemmaConfig"],
     "models.patchtsmixer": ["PatchTSMixerConfig"],
     "models.patchtst": ["PatchTSTConfig"],
     "models.pegasus": [
@@ -2651,6 +2652,13 @@ else:
             "OwlViTVisionModel",
         ]
     )
+    _import_structure["models.paligemma"].extend(
+        [
+            "PaliGemmaForConditionalGeneration",
+            "PaliGemmaPreTrainedModel",
+            "PaliGemmaProcessor",
+        ]
+    )
     _import_structure["models.patchtsmixer"].extend(
         [
             "PatchTSMixerForPrediction",
@@ -5126,6 +5134,9 @@ if TYPE_CHECKING:
         OwlViTTextConfig,
         OwlViTVisionConfig,
     )
+    from .models.paligemma import (
+        PaliGemmaConfig,
+    )
     from .models.patchtsmixer import (
         PatchTSMixerConfig,
     )
@@ -6956,6 +6967,11 @@ if TYPE_CHECKING:
         OwlViTTextModel,
         OwlViTVisionModel,
     )
+    from .models.paligemma import (
+        PaliGemmaForConditionalGeneration,
+        PaliGemmaPreTrainedModel,
+        PaliGemmaProcessor,
+    )
     from .models.patchtsmixer import (
         PatchTSMixerForPrediction,
         PatchTSMixerForPretraining,
...
@@ -173,6 +173,7 @@ from . import (
     opt,
     owlv2,
     owlvit,
+    paligemma,
     patchtsmixer,
     patchtst,
     pegasus,
...
@@ -182,6 +182,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("opt", "OPTConfig"),
         ("owlv2", "Owlv2Config"),
         ("owlvit", "OwlViTConfig"),
+        ("paligemma", "PaliGemmaConfig"),
         ("patchtsmixer", "PatchTSMixerConfig"),
         ("patchtst", "PatchTSTConfig"),
         ("pegasus", "PegasusConfig"),
@@ -464,6 +465,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("opt", "OPT"),
         ("owlv2", "OWLv2"),
         ("owlvit", "OWL-ViT"),
+        ("paligemma", "PaliGemma"),
         ("patchtsmixer", "PatchTSMixer"),
         ("patchtst", "PatchTST"),
         ("pegasus", "Pegasus"),
...
@@ -93,6 +93,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("oneformer", "OneFormerImageProcessor"),
         ("owlv2", "Owlv2ImageProcessor"),
         ("owlvit", "OwlViTImageProcessor"),
+        ("paligemma", "CLIPImageProcessor"),
         ("perceiver", "PerceiverImageProcessor"),
         ("pix2struct", "Pix2StructImageProcessor"),
         ("poolformer", "PoolFormerImageProcessor"),
...
@@ -313,6 +313,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
         ("nezha", "NezhaForPreTraining"),
         ("nllb-moe", "NllbMoeForConditionalGeneration"),
         ("openai-gpt", "OpenAIGPTLMHeadModel"),
+        ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("retribert", "RetriBertModel"),
         ("roberta", "RobertaForMaskedLM"),
         ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
@@ -697,6 +698,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
         ("kosmos-2", "Kosmos2ForConditionalGeneration"),
         ("llava", "LlavaForConditionalGeneration"),
         ("llava_next", "LlavaNextForConditionalGeneration"),
+        ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),
         ("vipllava", "VipLlavaForConditionalGeneration"),
         ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
...
@@ -74,6 +74,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("oneformer", "OneFormerProcessor"),
         ("owlv2", "Owlv2Processor"),
         ("owlvit", "OwlViTProcessor"),
+        ("paligemma", "PaliGemmaProcessor"),
         ("pix2struct", "Pix2StructProcessor"),
         ("pop2piano", "Pop2PianoProcessor"),
         ("sam", "SamProcessor"),
...
@@ -331,6 +331,7 @@ else:
         ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
         ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
         ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+        ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         (
             "pegasus",
             (
...
@@ -794,7 +794,6 @@ GEMMA_INPUTS_DOCSTRING = r"""
     "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
     GEMMA_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
 class GemmaModel(GemmaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
@@ -988,8 +987,6 @@ class GemmaModel(GemmaPreTrainedModel):
         if attention_mask is not None and attention_mask.dim() == 4:
             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
             causal_mask = torch.full(
...
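Note on the two removed lines above: Gemma treats a user-supplied 4D `attention_mask` as already being in additive ("inverted") form, i.e. `0.0` where attention is allowed and a large negative value where it is blocked; PaliGemma relies on this to give the image and prefix tokens full block attention. A hypothetical sketch of such a mask (sequence length and split point chosen purely for illustration):

```python
import torch

# Additive 4D mask of shape (batch, 1, query_len, kv_len): 0.0 = attend, min_value = blocked.
batch, seq_len, prefix_len = 1, 5, 3  # e.g. 3 image/prefix positions followed by 2 suffix positions
min_value = torch.finfo(torch.float32).min

allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal baseline
allowed[:, :prefix_len] = True  # every position may attend to the whole image/prefix block
mask = torch.zeros(seq_len, seq_len)
mask[~allowed] = min_value
mask = mask[None, None, :, :].expand(batch, 1, seq_len, seq_len)
print(mask[0, 0])
```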
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_paligemma": ["PaliGemmaConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_paligemma"] = [
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
]
_import_structure["processing_paligemma"] = ["PaliGemmaProcessor"]
if TYPE_CHECKING:
from .configuration_paligemma import PaliGemmaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
)
from .processing_paligemma import PaliGemmaProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaliGemmamodel configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class PaliGemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate a
PaliGemma model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`SiglipVisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone, typically a `GemmaConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
Vocabulary size of the PaliGemma model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`~PaliGemmaForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 2048):
Dimension of the multimodal projection space.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden layer of the Language model.
Example:
```python
>>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a PaliGemma config
>>> text_config = GemmaConfig()
>>> # Initializing a PaliGemma paligemma-3b-224 style configuration
>>> configuration = PaliGemmaConfig(vision_config, text_config)
>>> # Initializing a model from the paligemma-3b-224 style configuration
>>> model = PaliGemmaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "paligemma"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
self.is_encoder_decoder = False
if isinstance(self.vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
)
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
intermediate_size=4096,
hidden_size=1152,
patch_size=14,
image_size=224,
num_hidden_layers=27,
num_attention_heads=16,
vocab_size=257152,
vision_use_head=False,
)
self.vocab_size = self.vocab_size
self.text_config = text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.vocab_size = self.text_config.vocab_size
elif text_config is None:
self.text_config = CONFIG_MAPPING["gemma"](
hidden_size=2048,
num_hidden_layers=18,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1,
is_encoder_decoder=False,
)
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
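For reference, the `num_image_tokens` assignment above reduces to `(image_size // patch_size) ** 2`; with the patch size of 14 used throughout this PR, the three released resolutions give:

```python
# num_image_tokens = (image_size // patch_size) ** 2, as set on text_config in PaliGemmaConfig above.
for image_size in (224, 448, 896):
    print(image_size, (image_size // 14) ** 2)
# 224 256
# 448 1024
# 896 4096
```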
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert PaliGemma checkpoints from the original repository.
"""
import argparse
import collections
import torch
from numpy import load
from transformers import (
AutoTokenizer,
GemmaTokenizer,
GemmaTokenizerFast,
PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
PaliGemmaProcessor,
SiglipImageProcessor,
)
from transformers.tokenization_utils_base import AddedToken
from transformers.utils import logging
device = "cuda" # "cpu"
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# TODO add sequence length variations here
PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"]
def get_paligemma_config(variant: str, precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
if variant in PALIGEMMA_VARIANTS:
image_size = image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152 if variant != "2b-test" else 256000
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
else:
raise ValueError(f"Identifier {variant} not supported. Available: {PALIGEMMA_VARIANTS}")
return final_config
def slice_state_dict(state_dict, config):
# fmt: off
# patch embeddings
state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose(
3, 2, 0, 1
)
state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias")
# positional embeddings
state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale")
encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias")
encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale")
encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias")
encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel")
encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias")
encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel")
encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias")
encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel")
encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias")
encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel")
encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias")
encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel")
encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias")
encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel")
encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose()
state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias")
# multimodal projector
state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose()
state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias")
# text decoder (gemma)
embedding_vector = state_dict.pop("llm/embedder/input_embedding")
state_dict["language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w")
llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w")
llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w")
llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum")
llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale")
llm_post_attention_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale")
state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
return state_dict
def flatten_nested_dict(params, parent_key="", sep="/"):
items = []
for k, v in params.items():
k = k.removeprefix("params/")
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
@torch.no_grad()
def convert_paligemma_checkpoint(
checkpoint_path,
tokenizer_model_file,
pytorch_dump_folder_path,
variant: str,
precision: str,
do_convert_weights=False,
):
"""
Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed.
"""
config = get_paligemma_config(variant, precision=precision)
if do_convert_weights:
if variant == "2b-test":
# for the test model, the vocabulary was smaller
tokenizer_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
else:
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
tokenizer = tokenizer_class(tokenizer_model_file)
image_token = AddedToken("<image>", normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
# tokenizer.padding_side = 'right' # uncomment for testing purposes only.
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size}
image_processor.image_seq_length = config.vision_config.num_image_tokens
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
data = load(checkpoint_path)
state_dict = flatten_nested_dict(data)
del data
state_dict_transformers = slice_state_dict(state_dict, config)
del state_dict
model = PaliGemmaForConditionalGeneration(config).to(device).eval()
model.load_state_dict(state_dict_transformers)
del state_dict_transformers
else:
processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path)
model = (
PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa")
.to(device)
.eval()
)
model.config.text_config._attn_implementation = "sdpa"
# model expansion to get random embeds of image tokens
pad_shape = 64 # for performance reasons
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we resize the model
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[257152:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))),
dim=0,
)
model.save_pretrained(pytorch_dump_folder_path, max_shard_size="2GB", safe_serialization=True)
processor.save_pretrained(pytorch_dump_folder_path)
#
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path",
required=True,
type=str,
help="Path to the .npz checkpoint",
)
parser.add_argument(
"--tokenizer_model_file",
required=True,
type=str,
help="Path to the sentencepiece tokenizer.model file",
)
parser.add_argument(
"--pytorch_dump_folder_path",
required=True,
type=str,
help="Path to the output directory where model and processor will be saved.",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
parser.add_argument(
"--variant",
default="2b-test",
choices=PALIGEMMA_VARIANTS,
type=str,
help="String identifier of the paligemma variant to convert.",
)
parser.add_argument(
"--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights."
)
args = parser.parse_args()
convert_paligemma_checkpoint(
checkpoint_path=args.checkpoint_path,
tokenizer_model_file=args.tokenizer_model_file,
pytorch_dump_folder_path=args.pytorch_dump_folder_path,
variant=args.variant,
precision=args.precision,
do_convert_weights=args.do_convert_weights,
)
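A quick sanity check of the variant table above, assuming `get_paligemma_config` from this script is importable (or pasted into a session); the printed values follow directly from the dictionaries defined in the script:

```python
# Assumes get_paligemma_config from the conversion script above is in scope.
config = get_paligemma_config("3b-448px", precision="float32")

print(config.image_token_index)             # 257152 for the non-test variants
print(config.vision_config.image_size)      # 448
print(config.text_config.num_image_tokens)  # (448 // 14) ** 2 == 1024
```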
This diff is collapsed.
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for PaliGemma.
"""
import logging
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import AddedToken, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
logger = logging.getLogger(__name__)
IMAGE_TOKEN = "<image>"
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
def _is_str_or_image(elem):
return isinstance(elem, (str)) or is_image_or_image_url(elem)
def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
"""
Builds a string from the input prompt and image tokens.
For example, for the call:
build_string_from_input(
prompt="Prefix str"
bos_token="<s>",
image_seq_len=3,
image_token="<im>",
)
The output will be:
"<im><im><im><s>Initial str"
Args:
prompt (`List[Union[str, ImageInput]]`): The input prompt.
bos_token (`str`): The beginning of sentence token.
image_seq_len (`int`): The length of the image sequence.
image_token (`str`): The image token.
"""
return f"{image_token * image_seq_len}{bos_token}{prompt}"
class PaliGemmaProcessor(ProcessorMixin):
r"""
Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
[`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information.
Args:
image_processor ([`SiglipImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "SiglipImageProcessor"
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
if not hasattr(image_processor, "image_seq_length"):
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
self.image_seq_length = image_processor.image_seq_length
image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
super().__init__(image_processor, tokenizer)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
tokenize_newline_separately: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
do_resize: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821
resample: "PILImageResampling" = None, # noqa: F821
do_convert_rgb: bool = None,
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_rescale: bool = None,
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
tokenize_newline_separately (`bool`, defaults to `True`):
Adds a separately tokenized '\n' at the end of the prompt.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is None:
raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
if text is None:
logger.warning_once(
"You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
)
if isinstance(text, List) and isinstance(images, List):
if len(images) < len(text):
raise ValueError(
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
)
if _is_str_or_image(text):
text = [text]
elif isinstance(text, list) and _is_str_or_image(text[0]):
pass
input_strings = [
build_string_from_input(
prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=IMAGE_TOKEN,
)
for prompt in text
]
pixel_values = self.image_processor(
images,
do_resize=do_resize,
do_normalize=do_normalize,
return_tensors=return_tensors,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
data_format=data_format,
resample=resample,
do_convert_rgb=do_convert_rgb,
)["pixel_values"]
if max_length is not None:
max_length += self.image_seq_length # max_length has to account for the image tokens
if tokenize_newline_separately:
inputs = self.tokenizer(
input_strings,
add_special_tokens=False,
return_tensors=None,
padding="do_not_pad",
max_length=max_length,
truncation=truncation,
)
newline_token = self.tokenizer.convert_tokens_to_ids("\n")
concatenated_ids = [ids + [newline_token] for ids in inputs["input_ids"]]
concatenated_attention_masks = [mask + [1] for mask in inputs["attention_mask"]]
text_inputs = self.tokenizer.pad(
{"input_ids": concatenated_ids, "attention_mask": concatenated_attention_masks},
max_length=max_length,
padding=padding,
return_tensors=return_tensors,
)
else:
text_inputs = self.tokenizer(
input_strings,
add_special_tokens=False,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
truncation=truncation,
)
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
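A minimal sketch of what the processor above returns; the checkpoint id and image URL are assumed for illustration only:

```python
import requests
from PIL import Image
from transformers import PaliGemmaProcessor

processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")  # assumed checkpoint id

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

# The prompt string becomes image_seq_length * "<image>" + "<bos>" + text, with "\n" tokenized separately.
inputs = processor(text="answer en where is the cat?", images=image, return_tensors="pt")
print(inputs["input_ids"].shape)     # (1, image_seq_length + text tokens + newline)
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224)
```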
@@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Union
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
+    convert_to_rgb,
     resize,
     to_channel_dimension_format,
 )
@@ -73,6 +74,8 @@ class SiglipImageProcessor(BaseImageProcessor):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
             Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
     """

     model_input_names = ["pixel_values"]
@@ -87,6 +90,7 @@ class SiglipImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -102,6 +106,7 @@ class SiglipImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
+        self.do_convert_rgb = do_convert_rgb
         self._valid_processor_keys = [
             "images",
             "do_resize",
@@ -115,6 +120,7 @@ class SiglipImageProcessor(BaseImageProcessor):
             "return_tensors",
             "data_format",
             "input_data_format",
+            "do_convert_rgb",
         ]

     def preprocess(
@@ -131,6 +137,7 @@ class SiglipImageProcessor(BaseImageProcessor):
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        do_convert_rgb: bool = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -176,6 +183,8 @@ class SiglipImageProcessor(BaseImageProcessor):
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
@@ -186,6 +195,7 @@ class SiglipImageProcessor(BaseImageProcessor):
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

         images = make_list_of_images(images)
@@ -209,6 +219,9 @@ class SiglipImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
         if is_scaled_image(images[0]) and do_rescale:
             logger.warning_once(
                 "It looks like you are trying to rescale already rescaled images. If the input"
...
@@ -881,7 +881,9 @@ class SiglipVisionTransformer(nn.Module):
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(config)
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.head = SiglipMultiheadAttentionPoolingHead(config)
+        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+        if self.use_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(config)

     @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
@@ -915,14 +917,13 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)

-        pooled_output = self.head(last_hidden_state)
-
+        pooler_output = self.head(last_hidden_state) if self.use_head else None
         if not return_dict:
-            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+            return (last_hidden_state, pooler_output) + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
-            pooler_output=pooled_output,
+            pooler_output=pooler_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
@@ -959,6 +960,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
 class SiglipVisionModel(SiglipPreTrainedModel):
     config_class = SiglipVisionConfig
     main_input_name = "pixel_values"
+    _no_split_modules = ["SiglipVisionTransformer"]

     def __init__(self, config: SiglipVisionConfig):
         super().__init__(config)
...
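A small sketch of the `vision_use_head` switch introduced above; the tiny config values are purely illustrative and do not correspond to a real checkpoint:

```python
import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

# With vision_use_head=False (as PaliGemma sets it), the pooling head is skipped and
# pooler_output comes back as None.
config = SiglipVisionConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, image_size=28, patch_size=14, vision_use_head=False,
)
model = SiglipVisionModel(config).eval()
with torch.no_grad():
    outputs = model(pixel_values=torch.randn(1, 3, 28, 28))
print(outputs.last_hidden_state.shape)  # torch.Size([1, 4, 64]) -> (28 // 14) ** 2 = 4 patches
print(outputs.pooler_output)            # None
```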
@@ -6135,6 +6135,27 @@ class OwlViTVisionModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class PaliGemmaForConditionalGeneration(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+class PaliGemmaPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+class PaliGemmaProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class PatchTSMixerForPrediction(metaclass=DummyObject):
     _backends = ["torch"]
...