Unverified Commit caa0ff0b authored by Pablo Montalvo's avatar Pablo Montalvo Committed by GitHub
Browse files

Add fuyu model (#26911)



* initial commit

* add processor, add fuyu naming

* add draft processor

* fix processor

* remove dropout to fix loading of weights

* add image processing fixes from Pedro

* fix

* fix processor

* add basic processing fuyu test

* add documentation and TODO

* address comments, add tests, add doc

* replace assert with torch asserts

* add Mixins and fix tests

* clean imports

* add model tester, clean imports

* fix embedding test

* add updated tests from pre-release model

* Processor: return input_ids used for inference

* separate processing and model tests

* relax test tolerance for embeddings

* add test for logit comparison

* make sure fuyu image processor is imported in the init

* fix formattingh

* more formatting issues

* and more

* fixups

* remove some stuff

* nits

* update init

* remove the fuyu file

* Update integration test with release model

* Update conversion script.

The projection is not used, as confirmed by the authors.

* improve geenration

* Remove duplicate function

* Trickle down patches to model call

* processing fuyu updates

* remove things

* fix prepare_inputs_for_generation to fix generate()

* remove model_input

* update

* add generation tests

* nits

* draft leverage automodel and autoconfig

* nits

* fix dtype patch

* address comments, update READMEs and doc, include tests

* add working processing test, remove refs to subsequences

* add tests, remove Sequence classification

* processing

* update

* update the conversion script

* more processing cleanup

* safe import

* take out ModelTesterMixin for early release

* more cl;eanup

* more cleanup

* more cleanup

* and more

* register a buffer

* nits

* add postprocessing of generate output

* nits

* updates

* add one working test

* fix test

* make fixup works

* fixup

* Arthur's updates

* nits

* update

* update

* fix processor

* update tests

* passe more fixups

* fix

* nits

* don't import torch

* skip fuyu config for now

* fixup done

* fixup

* update

* oups

* nits

* Use input embeddings

* no buffer

* update

* styling processing fuyu

* fix test

* update licence

* protect torch import

* fixup and update not doctested

* kwargs should be passed

* udpates

* update the impofixuprts in the test

* protect import

* protecting imports

* protect imports in type checking

* add testing decorators

* protect top level import structure

* fix typo

* fix check init

* move requires_backend to functions

* Imports

* Protect types

---------
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarArthurZucker <arthur.zucker@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: default avatarLysandre <lysandre@huggingface.co>
parent 5a73316b
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import warnings
import flatdict
import torch
from transformers import FuyuConfig, FuyuForCausalLM, LlamaTokenizer
try:
from transformers import LlamaTokenizerFast
tokenizer_class = LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
tokenizer_class = LlamaTokenizer
"""
Sample usage: # TODO fix clone links from persimmon to fuyu
```
git clone https://github.com/adept-ai-labs/adept-inference
wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_base_model_release.tar
wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_chat_model_release.tar
python src/transformers/models/fuyu/convert_fuyu_weights_to_hf.py --input_dir /path/to/downloaded/fuyu/weights/ --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import FuyuForCausalLM, FuyuTokenizer
model = FuyuForCausalLM.from_pretrained("/output/path")
tokenizer = FuyuTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
KEYS_TO_MODIFY_MAPPING = {
"self_attention": "self_attn",
"language_model.encoder": "language_model.model",
"word_embeddings_for_head": "language_model.lm_head",
"language_model.embedding.word_embeddings": "language_model.model.embed_tokens",
"vit_encoder.linear_encoder": "vision_embed_tokens",
}
KEYS_TO_REMOVE = {
"rotary_emb.inv_freq",
"image_patch_projection",
"image_patch_projection.weight",
"image_patch_projection.bias",
}
def rename_state_dict(state_dict):
model_state_dict = {}
for key, value in state_dict.items():
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
# if KEYS_TO_REMOVE in key:
if key in KEYS_TO_REMOVE:
continue
model_state_dict[key] = value
return model_state_dict
def convert_fuyu_checkpoint(pytorch_dump_folder_path, ada_lib_path, pt_model_path, safe_serialization=False):
sys.path.insert(0, ada_lib_path)
model_state_dict_base = torch.load(pt_model_path, map_location="cpu")
state_dict = flatdict.FlatDict(model_state_dict_base["model"], ".")
state_dict = rename_state_dict(state_dict)
transformers_config = FuyuConfig()
model = FuyuForCausalLM(transformers_config).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.save_pretrained(pytorch_dump_folder_path, safe_serialization=safe_serialization)
transformers_config.save_pretrained(pytorch_dump_folder_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Fuyu weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--pt_model_path",
help="Location of Fuyu `model_optim_rng.pt`",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--ada_lib_path",
help="Location of original source code from adept to deserialize .pt checkpoint",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
args = parser.parse_args()
spm_path = os.path.join(args.input_dir, "adept_vocab.model")
convert_fuyu_checkpoint(
pytorch_dump_folder_path=args.output_dir,
pt_model_path=args.pt_model_path,
safe_serialization=args.safe_serialization,
ada_lib_path=args.ada_lib_path,
)
tokenizer = tokenizer_class(spm_path, bos_token="|ENDOFTEXT|", eos_token="|ENDOFTEXT|")
tokenizer.save_pretrained(args.output_dir)
if __name__ == "__main__":
main()
import math
from typing import List, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor
from ...image_transforms import (
normalize,
pad,
resize,
)
from ...image_utils import to_numpy_array
from ...utils import is_torch_available, is_vision_available, logging, requires_backends
if is_vision_available():
import PIL
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class FuyuImageProcessor(BaseImageProcessor):
"""
This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
handle:
- Processing Images:
Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch
dimensions. The image output is always img_h ........................................... 1080 img_w
........................................... 1920 Then, it patches up these images using the patchify_image
function.
- Creating Image Input IDs:
For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
variable-sized images, each line of patches is terminated with a newline ID.
- Image Patch Indices:
For each image patch, the code maintains an index where these patches should be inserted in a token stream.
"""
model_input_names = [
"images",
"image_input_ids",
"image_patches",
"image_patch_indices_per_batch",
"image_patch_indices_per_subsequence",
]
def __init__(
self, target_height=1080, target_width=1920, padding_value=1.0, padding_mode: str = "constant", **kwargs
):
super().__init__(**kwargs)
self.target_width = target_width
self.target_height = target_height
self.padding_value = padding_value
self.padding_mode = padding_mode
def get_num_patches(self, img_h: int, img_w: int, patch_dim_h: int, patch_dim_w: int) -> int:
"""Calculate number of patches required to encode an image."""
if img_h % patch_dim_h != 0:
raise ValueError(f"{img_h=} must be divisible by {patch_dim_h=}")
if img_w % patch_dim_w != 0:
raise ValueError(f"{img_w=} must be divisible by {patch_dim_w=}")
num_patches_per_dim_h = img_h // patch_dim_h
num_patches_per_dim_w = img_w // patch_dim_w
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
return num_patches
def patchify_image(self, image: "torch.Tensor", patch_dim_h: int, patch_dim_w: int) -> "torch.Tensor":
"""
Convert an image into a tensor of patches.
Args:
image: Image to convert. Shape: [batch, channels, height, width]
patch_dim_h: Height of each patch.
patch_dim_w: Width of each patch.
"""
requires_backends(self, ["torch"])
# TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
# torch implementation is faster but does not handle non-squares
batch_size, channels, height, width = image.shape
unfolded_along_height = image.unfold(2, patch_dim_h, patch_dim_h)
patches = unfolded_along_height.unfold(3, patch_dim_w, patch_dim_w)
patches_reshaped = patches.contiguous().view(batch_size, channels, -1, patch_dim_h, patch_dim_w)
patches_final = patches_reshaped.permute(0, 2, 3, 4, 1).reshape(
batch_size, -1, channels * patch_dim_h * patch_dim_w
)
return patches_final
def process_images_for_model_input(
self,
image_input: "torch.Tensor",
image_present: "torch.Tensor",
image_unpadded_h: "torch.Tensor",
image_unpadded_w: "torch.Tensor",
image_patch_dim_h: int,
image_patch_dim_w: int,
image_placeholder_id: int,
image_newline_id: int,
variable_sized: bool,
) -> dict:
"""Process images for model input. In particular, variable-sized images are handled here.
Args:
image_input: [batch_size, 1, c, h, w] tensor of images padded to model input size.
image_present: [batch_size, 1] tensor of 1s and 0s indicating whether an image is present.
image_unpadded_h: [batch_size, 1] tensor of unpadded image heights.
image_unpadded_w: [batch_size, 1] tensor of unpadded image widths.
image_patch_dim_h: The height of the image patches.
image_patch_dim_w: The width of the image patches.
image_placeholder_id: The id of the image placeholder token.
image_newline_id: The id of the image newline token.
variable_sized: Whether to process images as variable-sized.
"""
requires_backends(self, ["torch"])
# Only images that are present.
images: List[List[torch.Tensor]] = []
image_patches: List[List[torch.Tensor]] = []
# Image input ids for every subsequence, including ones with no image present.
image_input_ids: List[List[torch.Tensor]] = []
for bi in range(image_input.shape[0]):
images.append([])
image_input_ids.append([])
image_patches.append([])
for si in range(image_input.shape[1]):
if image_present[bi, si]:
image = image_input[bi, si]
if variable_sized:
# The min() is required here due to floating point issues:
# math.ceil(torch.tensor(300).cuda() / 30) == 11
new_h = min(
image.shape[1], math.ceil(image_unpadded_h[bi, si] / image_patch_dim_h) * image_patch_dim_h
)
new_w = min(
image.shape[2], math.ceil(image_unpadded_w[bi, si] / image_patch_dim_w) * image_patch_dim_w
)
image = image[:, :new_h, :new_w]
images[bi].append(image)
num_patches = self.get_num_patches(
img_h=image.shape[1],
img_w=image.shape[2],
patch_dim_h=image_patch_dim_h,
patch_dim_w=image_patch_dim_w,
)
ids = torch.full([num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device)
patches = self.patchify_image(
image=image.unsqueeze(0), patch_dim_h=image_patch_dim_h, patch_dim_w=image_patch_dim_w
).squeeze(0)
if variable_sized:
# Now terminate each line with |NEWLINE|.
ids = ids.reshape(-1, new_w // image_patch_dim_w)
ids = torch.cat(
[
ids,
torch.full(
[ids.shape[0], 1], image_newline_id, dtype=torch.int32, device=image_input.device
),
],
dim=1,
)
ids = ids.reshape(-1)
image_input_ids[bi].append(ids)
image_patches[bi].append(patches)
else:
image_input_ids[bi].append(torch.tensor([], dtype=torch.int32, device=image_input.device))
# Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
# the stream.
image_patch_indices_per_batch: List[List[torch.Tensor]] = []
image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
for bi in range(len(image_input_ids)):
image_patch_indices_per_batch.append([])
image_patch_indices_per_subsequence.append([])
index_offset = 0
for si in range(len(image_input_ids[bi])):
# Indices of image patches.
num_patches = torch.count_nonzero(image_input_ids[bi][si] == image_placeholder_id)
indices = torch.arange(
num_patches,
dtype=image_input_ids[bi][si].dtype,
device=image_input_ids[bi][si].device,
)
# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
indices_in_stream_per_batch = torch.full_like(image_input_ids[bi][si], -1)
indices_in_stream_per_subsequence = torch.full_like(image_input_ids[bi][si], -1)
indices_in_stream_per_batch[
torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
] = (indices + index_offset)
indices_in_stream_per_subsequence[
torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
] = indices
image_patch_indices_per_batch[bi].append(indices_in_stream_per_batch)
image_patch_indices_per_subsequence[bi].append(indices_in_stream_per_subsequence)
index_offset += num_patches
return {
"images": images,
"image_input_ids": image_input_ids,
"image_patches": image_patches,
"image_patch_indices_per_batch": image_patch_indices_per_batch,
"image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
}
def _scale_to_target_aspect_ratio(self, image: np.ndarray) -> np.ndarray:
image_height, image_width, _ = image.shape
if image_width <= self.target_width and image_height <= self.target_height:
return image
height_scale_factor = self.target_height / image_height
width_scale_factor = self.target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
new_height = int(image_height * optimal_scale_factor)
new_width = int(image_width * optimal_scale_factor)
scaled_image = resize(image=image, size=(new_width, new_height))
return np.array(scaled_image)
def _pad_to_target_size(self, image: np.ndarray) -> np.ndarray:
image_height, image_width, _ = image.shape
padding_top = 0
padding_left = 0
padding_bottom = self.target_height - image_height
padding_right = self.target_width - image_width
padded_image = pad(
image,
((padding_top, padding_bottom), (padding_left, padding_right)),
mode=self.padding_mode,
constant_values=self.padding_value,
)
return padded_image
def apply_transformation(self, image: Union[np.ndarray, PIL.Image.Image]) -> np.ndarray:
if isinstance(image, PIL.Image.Image):
image = to_numpy_array(image)
scaled_image = self._scale_to_target_aspect_ratio(image)
padded_image = self._pad_to_target_size(scaled_image)
normalized_padded_image = normalize(padded_image, 0.5, 0.5)
return normalized_padded_image
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModelForCausalLM
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_fuyu import FuyuConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "FuyuConfig"
FUYU_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`FuyuConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Fuyu Model outputting raw hidden-states without any specific head on top.",
FUYU_START_DOCSTRING,
)
class FuyuPreTrainedModel(PreTrainedModel):
config_class = FuyuConfig
base_model_prefix = "fuyu"
supports_gradient_checkpointing = True
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, FuyuForCausalLM):
module.gradient_checkpointing = value
FUYU_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Fuyu Model outputting raw hidden-states without any specific head on top.",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FuyuDecoderLayer`]
Args:
config: FuyuConfig
"""
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
continuous_embeddings: List[torch.Tensor],
image_patch_input_indices: torch.Tensor,
) -> torch.Tensor:
"""This function places the continuous_embeddings into the word_embeddings at the locations
indicated by image_patch_input_indices. Different batch elements can have different numbers of continuous
embeddings.
Args:
word_embeddings: Tensor of word embeddings. Shape: [b, s, h]
continuous_embeddings:
Tensor of continuous embeddings. The length of the list is the batch size. Each entry is
shape [num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
indices in image_patch_input_indices for that batch element.
image_patch_input_indices: Tensor of indices of the image patches in the input_ids tensor. Shape: [b, s]
"""
if not (word_embeddings.shape[0] == len(continuous_embeddings)):
raise ValueError(
f"Batch sizes must match! Got {len(continuous_embeddings)=} and {word_embeddings.shape[0]=}"
)
output_embeddings = word_embeddings.clone()
for batch_idx in range(word_embeddings.shape[0]):
# First, find the positions of all the non-negative values in image_patch_input_indices, those are the
# positions in word_embeddings that we want to replace with content from continuous_embeddings.
dst_indices = torch.nonzero(image_patch_input_indices[batch_idx] >= 0, as_tuple=True)[0]
# Next look up those indices in image_patch_input_indices to find the indices in continuous_embeddings that we
# want to use to replace the values in word_embeddings.
src_indices = image_patch_input_indices[batch_idx][dst_indices]
# Check if we have more indices than embeddings. Note that we could have fewer indices if images got truncated.
if src_indices.shape[0] > continuous_embeddings[batch_idx].shape[0]:
raise ValueError(
f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
)
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
return output_embeddings
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
image_patch_input_indices=image_patches_indices,
)
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return outputs
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
image_patches=None,
image_patches_indices=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if image_patches_indices is not None:
model_inputs["image_patches_indices"] = image_patches_indices
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"image_patches_indices": image_patches_indices if past_key_values is None else None,
"image_patches": image_patches if past_key_values is None else None,
}
)
return model_inputs
This diff is collapsed.
......@@ -3614,6 +3614,20 @@ def load_tf_weights_in_funnel(*args, **kwargs):
requires_backends(load_tf_weights_in_funnel, ["torch"])
class FuyuForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class FuyuPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
GIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -219,6 +219,13 @@ class FlavaProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class FuyuImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class GLPNFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
import unittest
import numpy as np
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import (
require_torch,
require_torchvision,
require_vision,
)
if is_torch_available() and is_vision_available():
import torch
from transformers import FuyuImageProcessor
if is_vision_available():
from PIL import Image
@require_torch
@require_vision
@require_torchvision
class TestFuyuImageProcessor(unittest.TestCase):
def setUp(self):
self.processor = FuyuImageProcessor(target_height=160, target_width=320, padding_value=1.0)
self.batch_size = 3
self.channels = 3
self.height = 300
self.width = 300
self.image_input = torch.rand(self.batch_size, self.channels, self.height, self.width)
self.image_patch_dim_h = 30
self.image_patch_dim_w = 30
self.sample_image = np.zeros((450, 210, 3), dtype=np.uint8)
self.sample_image_pil = Image.fromarray(self.sample_image)
def test_patches(self):
expected_num_patches = self.processor.get_num_patches(
img_h=self.height, img_w=self.width, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
)
patches_final = self.processor.patchify_image(
image=self.image_input, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
)
assert (
patches_final.shape[1] == expected_num_patches
), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}."
def test_scale_to_target_aspect_ratio(self):
scaled_image = self.processor._scale_to_target_aspect_ratio(self.sample_image)
self.assertEqual(scaled_image.shape[0], 74)
self.assertEqual(scaled_image.shape[1], 160)
def test_apply_transformation_numpy(self):
transformed_image = self.processor.apply_transformation(self.sample_image)
self.assertEqual(transformed_image.shape[0], 160)
self.assertEqual(transformed_image.shape[1], 320)
def test_apply_transformation_pil(self):
transformed_image = self.processor.apply_transformation(self.sample_image_pil)
self.assertEqual(transformed_image.shape[0], 160)
self.assertEqual(transformed_image.shape[1], 320)
import io
import unittest
import requests
from transformers import AutoTokenizer, FuyuConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from ...test_modeling_common import ids_tensor, random_attention_mask
if is_vision_available():
from PIL import Image
if is_torch_available() and is_vision_available():
from transformers import FuyuImageProcessor, FuyuProcessor
if is_torch_available():
import torch
from transformers import FuyuForCausalLM
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Fuyu
class FuyuModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
image_size=300,
patch_size=30,
num_channels=3,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return FuyuConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = FuyuForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = FuyuForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = FuyuForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = FuyuForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
@require_torch_gpu
@slow
class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
"""
Currently, all these tests depend on a value of max_tokens_to_generate of 10.
"""
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
def setUp(self):
self.pretrained_model_name = "huggingface/new_model_release_weights"
tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
image_processor = FuyuImageProcessor()
self.processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
self.bus_image_url = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
)
self.bus_image_pil = Image.open(io.BytesIO(requests.get(self.bus_image_url).content))
@slow
@require_torch_gpu
def test_model_8b_chat_greedy_generation_bus_captioning(self):
EXPECTED_TEXT_COMPLETION = """A bus parked on the side of a road.|ENDOFTEXT|"""
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
end_sequence = text[0].split("\x04")[1]
clean_sequence = (
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
if "|ENDOFTEXT|" in end_sequence
else end_sequence
)
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence[1:])
"""
@slow
@require_torch_gpu
def test_model_8b_chat_greedy_generation_bus_color(self):
EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|"
text_prompt_bus_color = "What color is the bus?\n"
model_inputs_bus_color = self.processor(text=text_prompt_bus_color, images=self.bus_image_pil)
generated_tokens = self.model.generate(**model_inputs_bus_color, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
end_sequence = text[0].split("\x04")[1]
clean_sequence = (
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
if "|ENDOFTEXT|" in end_sequence
else end_sequence
)
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence)
@slow
@require_torch_gpu
def test_model_8b_chat_greedy_generation_chart_vqa(self):
# fmt: off
EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",]
# fmt: on
expected_text_completion = " ".join(EXPECTED_TEXT_TOKENS) # TODO make sure the end string matches
text_prompt_chart_vqa = "What is the highest life expectancy at birth of male?\n"
chart_image_url = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/chart.png"
)
chart_image_pil = Image.open(io.BytesIO(requests.get(chart_image_url).content))
model_inputs_chart_vqa = self.processor(text=text_prompt_chart_vqa, images=chart_image_pil)
generated_tokens = self.model.generate(**model_inputs_chart_vqa, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
end_sequence = text[0].split("\x04")[1]
clean_sequence = (
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
if "|ENDOFTEXT|" in end_sequence
else end_sequence
)
self.assertEqual(expected_text_completion, clean_sequence)
@slow
@require_torch_gpu
def test_model_8b_chat_greedy_generation_bounding_box(self):
EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|"
text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231
bbox_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bbox_sample_image.png"
bbox_image_pil = Image.open(io.BytesIO(requests.get(bbox_image_url).content))
model_inputs_bbox = self.processor(text=text_prompt_bbox, images=bbox_image_pil)
generated_tokens = self.model.generate(**model_inputs_bbox, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
end_sequence = text[0].split("\x04")[1]
clean_sequence = (
end_sequence[: end_sequence.find("|ENDOFTEXT|") + len("|ENDOFTEXT|")]
if "|ENDOFTEXT|" in end_sequence
else end_sequence
)
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence)
"""
import io
import unittest
import requests
from transformers import AutoTokenizer, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow
if is_vision_available():
from PIL import Image
if is_vision_available() and is_torch_available():
from transformers import FuyuImageProcessor, FuyuProcessor
if is_torch_available():
import torch
from transformers.models.fuyu.processing_fuyu import construct_full_unpacked_stream, full_unpacked_stream_to_tensor
@require_torch
@require_torch_gpu
@slow
class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here?
""" """
def setUp(self):
pretrained_model_name = "huggingface/pre_release_model"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
text_prompt = "Generate a coco-style caption.\\n"
bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
self.one_image_bus_model_inputs = processor(text=text_prompt, images=bus_image_pil)
def test_fuyu_processing(self):
"""
Test to ensure that the standard processing on a gold example matches adept's code.
"""
# fmt: off
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
# fmt: on
torch.testing.assert_close(
self.one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS
)
torch.testing.assert_close(self.one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
@require_torch
class TestImageTextProcessingUtils(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.new_seq_len = 8
self.num_sub_sequences = 1
self.all_bi_tokens_to_place = [4, 6]
self.full_unpacked_stream = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8, 9, 10])]
self.fill_value = 0
self.num_real_text_tokens = [[3, 2], [2, 4]]
# Here the input stream is padded to avoid inconsistencies (current model release matches)
self.input_stream = torch.tensor([[[1, 2, 3], [4, 5, 0]], [[6, 7, 0], [8, 9, 10]]])
self.image_tokens = [
[torch.tensor([1, 2]), torch.tensor([3])],
[torch.tensor([4, 5, 6]), torch.tensor([7, 8])],
]
def test_full_unpacked_stream_to_tensor(self):
result = full_unpacked_stream_to_tensor(
self.all_bi_tokens_to_place,
self.full_unpacked_stream,
self.fill_value,
self.batch_size,
self.new_seq_len,
offset=0,
)
EXPECTED_TENSOR = torch.tensor([[1, 2, 3, 4, 0, 0, 0, 0], [5, 6, 7, 8, 9, 10, 0, 0]])
self.assertTrue(torch.equal(result, EXPECTED_TENSOR))
def test_construct_full_unpacked_stream(self):
result = construct_full_unpacked_stream(
self.num_real_text_tokens, self.input_stream, self.image_tokens, self.batch_size, self.num_sub_sequences
)
EXPECTED_UNPACKED_STREAM = [torch.tensor([1, 2, 1, 2, 3]), torch.tensor([4, 5, 6, 6, 7])]
for i in range(len(result)):
self.assertTrue(torch.equal(result[i], EXPECTED_UNPACKED_STREAM[i]))
@require_torch
class TestProcessImagesForModelInput(unittest.TestCase):
def setUp(self):
"""
Adding a mix of present and absent images.
"""
self.image_processor = FuyuImageProcessor()
self.image_input = torch.randn([1, 1, 3, 64, 64])
self.image_present = torch.tensor([[1]])
self.image_unpadded_h = torch.tensor([[45]]) # Adjusted for subsequence of 1
self.image_unpadded_w = torch.tensor([[50]]) # Adjusted for subsequence of 1
self.image_patch_dim_h = 16
self.image_patch_dim_w = 16
self.image_placeholder_id = 999
self.image_newline_id = 888
self.variable_sized = True
def test_process_images_for_model_input_fixed_sized(self):
self.variable_sized = False
result = self.image_processor.process_images_for_model_input(
image_input=self.image_input,
image_present=self.image_present,
image_unpadded_h=self.image_unpadded_h,
image_unpadded_w=self.image_unpadded_w,
image_patch_dim_h=self.image_patch_dim_h,
image_patch_dim_w=self.image_patch_dim_w,
image_placeholder_id=self.image_placeholder_id,
image_newline_id=self.image_newline_id,
variable_sized=self.variable_sized,
)
print(result["images"][0][0])
self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64]))
......@@ -36,6 +36,7 @@ SPECIAL_CASES_TO_ALLOW = {
"EncodecConfig": ["overlap"],
# used as `self.bert_model = BertModel(config, ...)`
"DPRConfig": True,
"FuyuConfig": True,
# not used in modeling files, but it's an important information
"FSMTConfig": ["langs"],
# used internally in the configuration class file
......
......@@ -79,6 +79,7 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"FuyuForCausalLM", # Not tested fort now
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"UMT5EncoderModel", # Building part of bigger (tested) model.
"Blip2QFormerModel", # Building part of bigger (tested) model.
......
......@@ -566,6 +566,7 @@ src/transformers/models/funnel/configuration_funnel.py
src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
src/transformers/models/funnel/modeling_funnel.py
src/transformers/models/funnel/modeling_tf_funnel.py
src/transformers/models/fuyu/convert_fuyu_model_weights_to_hf.py
src/transformers/models/git/configuration_git.py
src/transformers/models/git/convert_git_to_pytorch.py
src/transformers/models/glpn/configuration_glpn.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment