Unverified Commit 6b78360e authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add Idefics2 (#30253)



* Initial add model additions

* Test

* All weights loading

* Can perform full forward pass

* Local and remote the same

* Matching local and remote

* Fixup

* Idefics2Model importable; fixup docstrings

* Don't skip by default

* Remove deprecated use_resampler arg

* Remove self.config

* DecoupledLinear takes config

* Tidy up

* Enable eager attention and tidy up

* Most tests passing

* Update for batch of processed images

* Add image processor

* Update doc pages

* Update conversion script

* Remove erroneous breakpoint

* Remove accidendtal spelling change

* Update to reflect changes on hub - make generate work

* Fix up

* Image processor tests

* Update tests

* Add a processor

* Add a processor

* Update convert script

* Update modeling file - remove fixmes

* Bug fix

* Add processing test

* Use processor

* Fix up

* Update src/transformers/models/idefics2/modeling_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update src/transformers/models/idefics2/modeling_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Fix test

* Update config - PR comments and defaults align with checkpoint

* Reviewer comments

* Add copied froms for flahs attention

* Update src/transformers/models/idefics2/modeling_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove qk_layer_norm and freeze_layers functionality

* Fix

* Remove freeze_layer options from config

* Sync with upstream main

* Fix attention shapes siglip

* Remove Llava-next refs - TO REBASE

* Use AutoModel for text model

* Add comment to explain vision embeddings

* Fix issue with tie_word_embeddings

* Address review comments

* Fix and fix up

* Chat templates for idefics

* Fix copies

* Fix

* Add layer norms to FA2

* Fix tests

* Apply suggestions from code review
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Fix

* Review comments

* Update src/transformers/models/idefics2/modeling_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update inputs merger

* Merge weights in correct order

* Update convert script

* Update src/transformers/models/idefics2/processing_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update template

* Model code examples (fix idefics too)

* More review comments

* Tidy up

* Update processing

* Fix attention mask preparation

* Update inputs_merger inputs

* Vectorize inputs_merger

* Update src/transformers/models/idefics2/__init__.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/idefics2/modeling_idefics2.py

* Review comments

* saying bye to the `qk_layer_norms`

* Simplify

* Update latents

* Remove erroneuous readme changes

* Return images when applying chat template

* Fix bug - prompt images are for a single sample

* Update src/transformers/models/idefics2/modeling_idefics2.py

* image splitting

* fix test

* some more comment

* some comment

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics2/image_processing_idefics2.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update processor

* Update model tests

* Update src/transformers/models/idefics2/processing_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update src/transformers/models/idefics2/processing_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Don't add BOS in template

* Update src/transformers/models/idefics2/processing_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Remove index in examples

* Update tests to reflect #13

* Update src/transformers/models/idefics2/processing_idefics2.py
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* PR comment - consistent typing

* Update readme and model doc

* Update docs

* Update checkpoint references

* Update examples

* Fix and update tests

* Small addition

* Update tests - remove copied from as no ignore placement copy could be found

* Update example

* small fixes

* Update docs/source/en/model_doc/idefics2.md
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update docs/source/en/model_doc/idefics2.md
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Update README.md
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>

* Connector model as bridge

* Fix up

* Fix up

* Don't pass model inputs for generation kwargs update

* IDEFICS-2 -> Idefics2

* Remove config archive name

* IDEFICS-2 -> Idefics2

* Add back llava-next

* Update readmes

* Add requirements for processor tester

* Use custom convert_to_rgb to avoid possible BC

* Fix doc example

* Fix doc example

* Skip model doc tests - as model to large

* More doc example - account for image splitting

* Update src/transformers/image_transforms.py

* Fix config doctest

---------
Co-authored-by: default avatarPablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: default avatarArthurZucker <arthur.zucker@gmail.com>
Co-authored-by: default avatarVictor SANH <victorsanh@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 667939a2
......@@ -111,6 +111,7 @@ from . import (
hubert,
ibert,
idefics,
idefics2,
imagegpt,
informer,
instructblip,
......
......@@ -125,6 +125,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
......@@ -283,6 +284,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
]
)
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
......@@ -390,6 +392,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("hubert", "Hubert"),
("ibert", "I-BERT"),
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),
......
......@@ -71,6 +71,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("grounding-dino", "GroundingDinoImageProcessor"),
("groupvit", "CLIPImageProcessor"),
("idefics", "IdeficsImageProcessor"),
("idefics2", "Idefics2ImageProcessor"),
("imagegpt", "ImageGPTImageProcessor"),
("instructblip", "BlipImageProcessor"),
("kosmos-2", "CLIPImageProcessor"),
......
......@@ -120,6 +120,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("hubert", "HubertModel"),
("ibert", "IBertModel"),
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jukebox", "JukeboxModel"),
......@@ -287,6 +288,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("ibert", "IBertForMaskedLM"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
......@@ -678,6 +680,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("git", "GitForCausalLM"),
("idefics2", "Idefics2ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
......
......@@ -61,6 +61,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("groupvit", "CLIPProcessor"),
("hubert", "Wav2Vec2Processor"),
("idefics", "IdeficsProcessor"),
("idefics2", "Idefics2Processor"),
("instructblip", "InstructBlipProcessor"),
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
......
......@@ -201,6 +201,7 @@ else:
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("jukebox", ("JukeboxTokenizer", None)),
(
......
......@@ -1458,18 +1458,27 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
Example:
```python
>>> from transformers import AutoTokenizer, IdeficsForVisionText2Text
>>> from transformers import AutoProcessor, IdeficsForVisionText2Text
>>> model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")
>>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b")
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
>>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b")
>>> dogs_image_url_1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg"
>>> dogs_image_url_2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image2.jpeg"
>>> prompts = [
... [
... "User:",
... dogs_image_url_1,
... "Describe this image.\nAssistant: An image of two dogs.\n",
... "User:",
... dogs_image_url_2,
... "Describe this image.\nAssistant:",
... ]
... ]
>>> inputs = processor(prompts, return_tensors="pt")
>>> generate_ids = model.generate(**inputs, max_new_tokens=6)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_idefics2": ["Idefics2Config"]}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_idefics2"] = ["Idefics2ImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_idefics2"] = [
"IDEFICS2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Idefics2ForConditionalGeneration",
"Idefics2PreTrainedModel",
"Idefics2Model",
]
_import_structure["processing_idefics2"] = ["Idefics2Processor"]
if TYPE_CHECKING:
from .configuration_idefics2 import Idefics2Config
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_idefics2 import Idefics2ImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_idefics2 import (
IDEFICS2_PRETRAINED_MODEL_ARCHIVE_LIST,
Idefics2ForConditionalGeneration,
Idefics2Model,
Idefics2PreTrainedModel,
)
from .processing_idefics2 import Idefics2Processor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Idefics2 model configuration"""
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class Idefics2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Idefics2VisionModel`]. It is used to instantiate a
Idefics2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics2 model
[HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
intializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices in the model.
Example:
```python
>>> from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
>>> from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig
>>> # Initializing a Idefics2VisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = Idefics2VisionConfig()
>>> # Initializing a Idefics2VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = Idefics2VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics2"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=32,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from Idefics2Config
if config_dict.get("model_type") == "idefics2":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Idefics2PerceiverConfig(PretrainedConfig):
r"""
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the perceiver block.
resampler_n_latents (`int`, *optional*, defaults to 64):
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
resampler_depth (`int`, *optional*, defaults to 3):
Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
resampler_n_heads (`int`, *optional*, defaults to 16):
Number of heads in each Transformer block (for multi-headed self-attention).
resampler_head_dim (`int`, *optional*, defaults to 96):
Dimensionality of each head projection in the Transformer block.
num_key_value_heads (`int`, *optional*, defaults to 4):
Number of key-value heads in the perceiver attention block.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
"""
model_type = "idefics2"
def __init__(
self,
hidden_act="silu",
resampler_n_latents=64,
resampler_depth=3,
resampler_n_heads=16,
resampler_head_dim=96,
num_key_value_heads=4,
attention_dropout=0.0,
**kwargs,
):
self.hidden_act = hidden_act
self.resampler_n_latents = resampler_n_latents
self.resampler_depth = resampler_depth
self.resampler_n_heads = resampler_n_heads
self.num_key_value_heads = num_key_value_heads
self.resampler_head_dim = resampler_head_dim
self.attention_dropout = attention_dropout
if self.num_key_value_heads > self.resampler_n_heads:
raise ValueError(
f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
f" resampler_n_heads={self.resampler_n_heads}"
)
super().__init__(**kwargs)
class Idefics2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Idefics2Model`]. It is used to instantiate a
Idefics2 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the model of the Idefics2
[HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should cache the key/value pairs of the attention mechanism.
image_token_id (`int`, *optional*, defaults to 32001):
The id of the "image" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to tie the word embeddings with the token embeddings.
vision_config (`IdeficsVisionConfig` or `dict`, *optional*):
Custom vision config or dict
perceiver_config (`IdeficsPerceiverConfig` or `dict`, *optional*):
Custom perceiver config or dict
text_config (`MistralConfig` or `dict`, *optional*):
Custom text config or dict for the text model
Example:
```python
>>> from transformers import Idefics2Model, Idefics2Config
>>> # Initializing configuration
>>> configuration = Idefics2Config()
>>> # Initializing a model from the configuration
>>> model = Idefics2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics2"
is_composition = True
def __init__(
self,
use_cache=True,
image_token_id=32_001,
tie_word_embeddings=False,
vision_config=None,
perceiver_config=None,
text_config=None,
**kwargs,
):
self.image_token_id = image_token_id
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
if perceiver_config is None:
self.perceiver_config = Idefics2PerceiverConfig()
logger.info("perciver_config is None, using default perceiver config")
elif isinstance(perceiver_config, dict):
self.perceiver_config = Idefics2PerceiverConfig(**perceiver_config)
elif isinstance(perceiver_config, Idefics2PerceiverConfig):
self.perceiver_config = perceiver_config
if vision_config is None:
self.vision_config = Idefics2VisionConfig()
logger.info("vision_config is None, using default vision config")
elif isinstance(vision_config, dict):
self.vision_config = Idefics2VisionConfig(**vision_config)
elif isinstance(vision_config, Idefics2VisionConfig):
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
logger.info("text_config is None, using default text config")
text_config = CONFIG_MAPPING["mistral"](
max_position_embeddings=4096 * 8,
rms_norm_eps=1e-5,
# None in the original configuration_mistral, we set it to the unk_token_id
pad_token_id=0,
tie_word_embeddings=False,
)
self.text_config = text_config
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import torch
from accelerate import init_empty_weights
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
Idefics2Config,
Idefics2ForConditionalGeneration,
Idefics2ImageProcessor,
Idefics2Processor,
MistralConfig,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/idefics2/convert_idefics2_weights_to_hf.py --original_model_id HuggingFaceM4/idefics2-8b --output_hub_path org/idefics2
"""
KEYS_TO_MODIFY_MAPPING = {
"lm_head.weight": "lm_head.linear.weight",
"model.layers": "model.text_model.layers",
"model.norm": "model.text_model.norm",
"model.perceiver_resampler": "model.connector.perceiver_resampler",
"model.modality_projection": "model.connector.modality_projection",
}
WEIGHTS_TO_MERGE_MAPPING = (
# (weights to merge in merging order), (new weight name)
(
("model.embed_tokens.weight", "model.embed_tokens.additional_embedding.weight"),
"model.text_model.embed_tokens.weight",
),
(("lm_head.linear.weight", "additional_fc.weight"), "lm_head.weight"),
)
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
def merge_weights(state_dict):
new_state_dict = copy.deepcopy(state_dict)
# Merge the weights
for weights_to_merge, new_weight_name in WEIGHTS_TO_MERGE_MAPPING:
for weight in weights_to_merge:
assert weight in state_dict, f"Weight {weight} is missing in the state dict"
if new_weight_name not in new_state_dict:
new_state_dict[new_weight_name] = [state_dict[weight]]
else:
new_state_dict[new_weight_name].append(state_dict[weight])
new_state_dict[new_weight_name] = torch.cat(new_state_dict[new_weight_name], dim=0)
# Remove the weights that were merged
for weights_to_merge, new_weight_name in WEIGHTS_TO_MERGE_MAPPING:
for weight in weights_to_merge:
if weight in new_state_dict and weight != new_weight_name:
new_state_dict.pop(weight)
return new_state_dict
def get_config(checkpoint):
if checkpoint == "HuggingFaceM4/idefics2":
# We load the config then recreate to use the text_config
config = AutoConfig.from_pretrained(checkpoint)
text_config = MistralConfig(
vocab_size=config.vocab_size + config.additional_vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
hidden_act=config.hidden_act,
max_position_embeddings=config.max_position_embeddings,
initializer_range=config.initializer_range,
rms_norm_eps=config.rms_norm_eps,
tie_word_embeddings=config.tie_word_embeddings,
rope_theta=config.rope_theta,
sliding_window=config.sliding_window,
attention_dropout=config.attention_dropout,
pad_token_id=config.pad_token_id,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
)
perceiver_config = config.perceiver_config.to_dict()
config = Idefics2Config(
text_config=text_config.to_dict(),
vision_config=config.vision_config,
perceiver_config=perceiver_config,
use_cache=config.use_cache,
image_token_id=config.image_token_id,
tie_word_embeddings=config.tie_word_embeddings,
)
return config
return AutoConfig.from_pretrained(checkpoint)
def convert_idefics2_hub_to_hf(original_model_id, output_hub_path, push_to_hub):
# The original model maps to AutoModelForCausalLM, converted we map to Idefics2ForConditionalGeneration
original_model = AutoModelForCausalLM.from_pretrained(original_model_id, trust_remote_code=True)
# The original model doesn't use the idefics2 processing objects
image_seq_len = original_model.config.perceiver_config.resampler_n_latents
image_processor = Idefics2ImageProcessor()
tokenizer = AutoTokenizer.from_pretrained(original_model_id)
processor = Idefics2Processor(
image_processor=image_processor,
tokenizer=tokenizer,
image_seq_len=image_seq_len,
)
state_dict = original_model.state_dict()
state_dict = convert_state_dict_to_hf(state_dict)
# Merge weights
state_dict = merge_weights(state_dict)
config = get_config(original_model_id)
with init_empty_weights():
model = Idefics2ForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=True, assign=True)
model.save_pretrained(output_hub_path)
processor.save_pretrained(output_hub_path)
if push_to_hub:
model.push_to_hub(output_hub_path, private=True)
processor.push_to_hub(output_hub_path, private=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--original_model_id",
help="Hub location of the text model",
)
parser.add_argument(
"--output_hub_path",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="If set, the model will be pushed to the hub after conversion.",
)
args = parser.parse_args()
convert_idefics2_hub_to_hf(args.original_model_id, args.output_hub_path, args.push_to_hub)
if __name__ == "__main__":
main()
This diff is collapsed.
This diff is collapsed.
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for IDEFICS2.
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...utils import TensorType, logging
if TYPE_CHECKING:
from ...pipelines.conversational import Conversation
from ...tokenization_utils_base import PreTokenizedInput
logger = logging.get_logger(__name__)
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
class Idefics2Processor(ProcessorMixin):
r"""
Constructs a IDEFICS2 processor which wraps a LLama tokenizer and IDEFICS2 image processor into a single processor.
[`IdeficsProcessor`] offers all the functionalities of [`Idefics2ImageProcessor`] and [`LlamaTokenizerFast`]. See
the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
Args:
image_processor (`Idefics2ImageProcessor`):
An instance of [`Idefics2ImageProcessor`]. The image processor is a required input.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
image_seq_len (`int`, *optional*, defaults to 64):
The length of the image sequence i.e. the number of <image> tokens per image in the input.
This parameter is used to build the string from the input prompt and image tokens and should match the
config.perceiver_config.resampler_n_latents value for the model used.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "Idefics2ImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
self.image_token = AddedToken("<image>", normalized=False, special=True)
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
self.image_seq_len = image_seq_len
tokens_to_add = {
"additional_special_tokens": [self.fake_image_token, self.image_token, self.end_of_utterance_token]
}
tokenizer.add_special_tokens(tokens_to_add)
# Stores a Jinja template that formats chat histories into tokenizable strings
self.chat_template = kwargs.pop("chat_template", None)
super().__init__(image_processor, tokenizer)
def _extract_images_from_prompts(self, prompts):
prompt_images = []
for prompt in prompts:
images = []
for elem in prompt:
if is_valid_image(elem):
images.append(elem)
elif is_url(elem):
images.append(load_image(elem))
prompt_images.append(images)
return prompt_images
def __call__(
self,
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
image_seq_len: Optional[int] = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
is_split_into_words: bool = False,
add_special_tokens: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchEncoding:
"""
Processes the input prompts and returns a BatchEncoding.
Example:
```python
>>> import requests
>>> from transformers import Idefics2Processor
>>> from transformers.image_utils import load_image
>>> processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
>>> processor.image_processor.do_image_splitting = False # Force as False to simplify the example
>>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
>>> image1, image2 = load_image(url1), load_image(url2)
>>> images = [[image1], [image2]]
>>> text = [
... "<image>In this image, we see",
... "bla bla bla<image>",
... ]
>>> outputs = processor(text=text, images=images, return_tensors="pt", padding=True)
>>> input_ids = outputs.input_ids
>>> input_tokens = processor.tokenizer.batch_decode(input_ids)
>>> print(input_tokens)
['<s><fake_token_around_image><image><image><fake_token_around_image> In this image, we see', '<s> bla bla bla<fake_token_around_image><image><image><fake_token_around_image>']
```
Args:
text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
Wherever an image token, `<image>` is encountered it is expanded to
`<fake_token_around_image>` + `<image>` * `image_seq_len` * <fake_token_around_image>`.
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. If is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
image_seq_len (`int`, *optional*):
The length of the image sequence. If not provided, the default value is used.
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `False`):
Padding strategy applied to the input ids. See [`PreTrainedTokenizerFast.pad`] for more information.
truncation (`Union[bool, str, TruncationStrategy]`, *optional*):
Truncation strategy applied to the input ids. See [`PreTrainedTokenizerFast.truncate`] for more information.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding/truncation length. See
[`PreTrainedTokenizerFast.__call__`] for more information.
is_split_into_words (`bool`, *optional*, defaults to `False`):
Whether the input text is split into words or not. If set to `True`, the tokenizer will skip the
tokenization process and assume the input is already tokenized.
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether to add special tokens or not. See [`PreTrainedTokenizerFast.__call__`] for more information.
return_tensors (`Union[str, TensorType]`, *optional*):
If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
information.
"""
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
n_images_in_text = []
inputs = BatchFeature()
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
image_str = f"{fake_image_token}{image_token * image_seq_len}{fake_image_token}"
if self.image_processor.do_image_splitting:
# A single image token is split into 4 patches + 1 original image
image_str = image_str * 5
prompt_strings = []
for sample in text:
n_images_in_text.append(sample.count(image_token))
sample = sample.replace(image_token, image_str)
# Remove any double fake tokens if images are adjacent
sample = sample.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
prompt_strings.append(sample)
text_inputs = self.tokenizer(
text=prompt_strings,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
is_split_into_words=is_split_into_words,
return_tensors=return_tensors,
)
inputs.update(text_inputs)
if images is not None:
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
and not is_image_or_image_url(images[0][0])
):
raise ValueError(
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
)
n_images_in_images = [len(sample) for sample in images]
if text is not None and not n_images_in_images == n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)
# Load images if they are URLs
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, return_tensors=return_tensors)
inputs.update(image_inputs)
return inputs
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
chat_template: Optional[str] = None,
tokenize: bool = False,
**kwargs,
) -> str:
"""
Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default
if no chat template is provided.
By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert
the image token <image> into the sequence according to the message, but does not handle expanding the image
tokens to the sequence length or adding the surrounding tokens e.g. <fake_image_token>.
Args:
conversation (`Union[List[Dict, str, str], "Conversation"]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments for the tokenizer's `apply_chat_template` method.
"""
if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
else:
chat_template = self.default_chat_template
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)
@property
def default_chat_template(self):
"""
This template formats inputs in the form of a chat history. For each message in the chat history:
* the template will output the role of the speaker followed by the content of the message.
* content can be a single string or a list of strings and images.
* If the content element is an image, the template will output a sequence of <image> tokens and <fake_token_around_image> token before and after each image
* The template will output an <end_of_utterance> token at the end of each message.
Example:
```python
messages = [{
"role": "user",
"content": [
{"type": "text", "text": "What’s in this image?"},
{"type": "image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground."},]
}]
```
Will create outputs like:
```
User: What is in this Image?<image><image><end_of_utterance>
Assistant: This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground.<end_of_utterance>
```
"""
# fmt: off
return (
"{% for message in messages %}"
"{{message['role'].capitalize()}}"
"{% if message['content'][0]['type'] == 'image' %}"
"{{':'}}"
"{% else %}"
"{{': '}}"
"{% endif %}"
"{% for line in message['content'] %}"
"{% if line['type'] == 'text' %}"
"{{line['text']}}"
"{% elif line['type'] == 'image' %}"
"{{ '<image>' }}"
"{% endif %}"
"{% endfor %}"
"<end_of_utterance>\n"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ 'Assistant:' }}"
"{% endif %}"
)
# fmt: on
......@@ -4405,6 +4405,37 @@ class IdeficsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
IDEFICS2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Idefics2ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics2Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics2Processor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -261,6 +261,13 @@ class IdeficsImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Idefics2ImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ImageGPTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
if is_vision_available():
from PIL import Image
from transformers import Idefics2ImageProcessor
if is_torch_available():
import torch
class Idefics2ImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
num_images=1,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_convert_rgb=True,
do_pad=True,
do_image_splitting=True,
):
size = size if size is not None else {"shortest_edge": 378, "longest_edge": 980}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.num_images = num_images
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_convert_rgb = do_convert_rgb
self.do_pad = do_pad
self.do_image_splitting = do_image_splitting
def prepare_image_processor_dict(self):
return {
"do_convert_rgb": self.do_convert_rgb,
"do_resize": self.do_resize,
"size": self.size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_pad": self.do_pad,
"do_image_splitting": self.do_image_splitting,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to BridgeTowerImageProcessor,
assuming do_resize is set to True with a scalar size and size_divisor.
"""
if not batched:
shortest_edge = self.size["shortest_edge"]
longest_edge = self.size["longest_edge"]
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
aspect_ratio = w / h
if w > h and w >= longest_edge:
w = longest_edge
h = int(w / aspect_ratio)
elif h > w and h >= longest_edge:
h = longest_edge
w = int(h * aspect_ratio)
w = max(w, shortest_edge)
h = max(h, shortest_edge)
expected_height = h
expected_width = w
else:
expected_values = []
for images in image_inputs:
for image in images:
expected_height, expected_width = self.get_expected_values([image])
expected_values.append((expected_height, expected_width))
expected_height = max(expected_values, key=lambda item: item[0])[0]
expected_width = max(expected_values, key=lambda item: item[1])[1]
return expected_height, expected_width
def expected_output_image_shape(self, images):
height, width = self.get_expected_values(images, batched=True)
effective_nb_images = self.num_images * 5 if self.do_image_splitting else 1
return effective_nb_images, self.num_channels, height, width
def prepare_image_inputs(
self,
batch_size=None,
min_resolution=None,
max_resolution=None,
num_channels=None,
num_images=None,
size_divisor=None,
equal_resolution=False,
numpify=False,
torchify=False,
):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
One can specify whether the images are of the same resolution or not.
"""
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
batch_size = batch_size if batch_size is not None else self.batch_size
min_resolution = min_resolution if min_resolution is not None else self.min_resolution
max_resolution = max_resolution if max_resolution is not None else self.max_resolution
num_channels = num_channels if num_channels is not None else self.num_channels
num_images = num_images if num_images is not None else self.num_images
images_list = []
for i in range(batch_size):
images = []
for j in range(num_images):
if equal_resolution:
width = height = max_resolution
else:
# To avoid getting image width/height 0
if size_divisor is not None:
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
min_resolution = max(size_divisor, min_resolution)
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
images_list.append(images)
if not numpify and not torchify:
# PIL expects the channel dimension as last dimension
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
if torchify:
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
return images_list
@require_torch
@require_vision
class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Idefics2ImageProcessor if is_vision_available() else None
def setUp(self):
self.image_processor_tester = Idefics2ImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
This diff is collapsed.
This diff is collapsed.
......@@ -60,6 +60,7 @@ from transformers.models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from transformers.testing_utils import (
......@@ -220,6 +221,7 @@ class ModelTesterMixin:
*get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES),
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment