Unverified commit 8a312956 authored by Pablo Montalvo, committed by GitHub

Fuyu: improve image processing (#27007)



* Fix Fuyu image scaling bug

It could produce negative padding and hence inference errors for certain
image sizes.

* initial rework commit

* add batching capabilities, refactor image processing

* add functional batching for a list of images and texts

* make args explicit

* Fuyu processing update (#27133)

* Add file headers

* Add file headers

* First pass - preprocess method with standard args

* First pass image processor rework

* Small tweaks

* More args and docstrings

* Tidying iterating over batch

* Tidying up

* Modify to have quick tests (for now)

* Fix up

* BatchFeature

* Passing tests

* Add tests for processor

* Sense check when patchifying

* Add some tests

* FuyuBatchFeature

* Post-process box coordinates

* Update to `size` in processor

* Remove unused and duplicate constants

* Store unpadded dims after resize

* Fix up

* Return FuyuBatchFeature

* Get unpadded sizes after resize

* Update exception

* Fix return

* Convert input `<box>` coordinates to model format.

* Post-process point coords, support multiple boxes/points in a single
sequence

* Replace constants

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Preprocess List[List[image]]

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update to Amy's latest state.

* post-processing returns a list of tensors

* Fix error when target_sizes is None
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Review comments

* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix up

* Fix up

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Fix conflicts in fuyu_follow_up_image_processing (#27228)

fixing conflicts and updating on main

* Revert "Fix conflicts in fuyu_follow_up_image_processing" (#27232)

Revert "Fix conflicts in fuyu_follow_up_image_processing (#27228)"

This reverts commit acce10b6c653dc7041fb9d18cfed55775afd6207.

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
parent 9b25c164
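A rough usage sketch of the batched processing path this commit introduces. The checkpoint id, example image URL and prompt are illustrative assumptions, not part of the diff; any Fuyu checkpoint and PIL images should follow the same pattern.

import requests
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

# Assumed checkpoint and example image.
processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")

url = "https://huggingface.co/adept/fuyu-8b/resolve/main/bus.png"  # assumed example image
images = [Image.open(requests.get(url, stream=True).raw)] * 2
prompts = ["Generate a coco-style caption.\n"] * 2

# One prompt per image; the processor resizes, pads, patchifies and left-pads the batch.
inputs = processor(text=prompts, images=images, return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(processor.tokenizer.batch_decode(generated[:, -10:], skip_special_tokens=True))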
...@@ -112,17 +112,9 @@ class BatchFeature(UserDict):
def items(self):
return self.data.items()
def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
if tensor_type is None:
return None, None
# Convert to TensorType
if not isinstance(tensor_type, TensorType):
...@@ -167,6 +159,21 @@ class BatchFeature(UserDict):
return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
return is_tensor, as_tensor
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
"""
Convert the inner content to tensors.
Args:
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
"""
if tensor_type is None:
return self
is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
# Do the tensor conversion in batch
for key, value in self.items():
...
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Fuyu."""
import math
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
pad,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
make_list_of_images,
to_numpy_array,
)
from ...utils import (
TensorType,
is_torch_available,
is_torch_device,
is_torch_dtype,
logging,
requires_backends,
)
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
def make_list_of_list_of_images(
images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
) -> List[List[ImageInput]]:
if is_valid_image(images):
return [[images]]
if isinstance(images, list) and all(isinstance(image, list) for image in images):
return images
if isinstance(images, list):
return [make_list_of_images(image) for image in images]
raise ValueError("images must be a list of list of images or a list of images or an image.")
class FuyuBatchFeature(BatchFeature):
"""
BatchFeature class for Fuyu image processor and processor.
The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
"""
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
"""
Convert the inner content to tensors.
Args:
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
"""
if tensor_type is None:
return self
is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type)
def _convert_tensor(elem):
if is_tensor(elem):
return elem
return as_tensor(elem)
def _safe_convert_tensor(elem):
try:
return _convert_tensor(elem)
except: # noqa E722
if key == "overflowing_values":
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
)
# Do the tensor conversion in batch
for key, value in self.items():
if isinstance(value, list) and isinstance(value[0], list):
# List[List[Any]] -> List[List[Tensor]]
self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value]
elif isinstance(value, list):
# List[Any] -> List[Tensor]
self[key] = [_safe_convert_tensor(elem) for elem in value]
else:
# Any -> Tensor
self[key] = _safe_convert_tensor(value)
return self
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
Args:
args (`Tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
arg = args[0]
if is_torch_dtype(arg):
# The first argument is a dtype
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
def _to(elem):
# check if v is a floating point
if torch.is_floating_point(elem):
# cast and send to device
return elem.to(*args, **kwargs)
if device is not None:
return elem.to(device=device)
return elem
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
if isinstance(v, list) and isinstance(v[0], list):
# Data structure is a list of lists
new_v = []
for elems in v:
new_v.append([_to(elem) for elem in elems])
new_data[k] = new_v
elif isinstance(v, list):
# Data structure is a list
new_data[k] = [_to(elem) for elem in v]
else:
new_data[k] = _to(v)
self.data = new_data
return self
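For illustration, here is a minimal standalone sketch (independent of the class above and of transformers) of the nested-list conversion rule that FuyuBatchFeature applies, assuming torch is available:

import numpy as np
import torch

def convert_nested(value):
    # List[List[Any]] -> List[List[Tensor]], List[Any] -> List[Tensor], Any -> Tensor
    if isinstance(value, list) and value and isinstance(value[0], list):
        return [[torch.as_tensor(elem) for elem in elems] for elems in value]
    if isinstance(value, list):
        return [torch.as_tensor(elem) for elem in value]
    return torch.as_tensor(value)

data = {
    "images": [[np.zeros((3, 1080, 1920), dtype=np.float32)]],  # List[List[np.ndarray]]
    "image_unpadded_heights": [[540]],
    "image_unpadded_widths": [[960]],
}
converted = {key: convert_nested(value) for key, value in data.items()}
print(type(converted["images"][0][0]))  # <class 'torch.Tensor'>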
class FuyuImageProcessor(BaseImageProcessor):
"""
This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
...@@ -29,9 +184,9 @@ class FuyuImageProcessor(BaseImageProcessor):
- Processing Images:
Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch
dimensions. The image output is always img_h, img_w of (1080, 1920).
Then, it patches up these images using the patchify_image function.
- Creating Image Input IDs:
For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
...@@ -40,6 +195,32 @@ class FuyuImageProcessor(BaseImageProcessor):
- Image Patch Indices:
For each image patch, the code maintains an index where these patches should be inserted in a token stream.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image to `size`.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to `size`.
padding_value (`float`, *optional*, defaults to 1.0):
The value to pad the image with.
padding_mode (`str`, *optional*, defaults to `"constant"`):
The padding mode to use when padding the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image.
image_mean (`float`, *optional*, defaults to 0.5):
The mean to use when normalizing the image.
image_std (`float`, *optional*, defaults to 0.5):
The standard deviation to use when normalizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `1 / 255`):
The factor to use when rescaling the image.
patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
""" """
model_input_names = [ model_input_names = [
...@@ -51,204 +232,483 @@ class FuyuImageProcessor(BaseImageProcessor): ...@@ -51,204 +232,483 @@ class FuyuImageProcessor(BaseImageProcessor):
] ]
def __init__( def __init__(
self, target_height=1080, target_width=1920, padding_value=1.0, padding_mode: str = "constant", **kwargs self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_pad: bool = True,
padding_value: float = 1.0,
padding_mode: str = "constant",
do_normalize: bool = True,
image_mean: Union[float, List[float]] = 0.5,
image_std: Union[float, List[float]] = 0.5,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
patch_size: Optional[Dict[str, int]] = None,
**kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.target_width = target_width self.do_resize = do_resize
self.target_height = target_height self.size = size if size is not None else {"height": 1080, "width": 1920}
self.resample = resample
self.do_pad = do_pad
self.padding_value = padding_value self.padding_value = padding_value
self.padding_mode = padding_mode self.padding_mode = padding_mode
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
image_height, image_width = get_image_size(image, input_data_format)
target_height, target_width = size["height"], size["width"]
if image_width <= target_width and image_height <= target_height:
return image
height_scale_factor = target_height / image_height
width_scale_factor = target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
new_height = int(image_height * optimal_scale_factor)
new_width = int(image_width * optimal_scale_factor)
scaled_image = resize(
image=image,
size=(new_height, new_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
return scaled_image
def pad_image(
self,
image: np.ndarray,
size: Dict[str, int],
mode: str = "constant",
constant_values: float = 1.0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to pad.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
data_format (`ChannelDimension` or `str`, *optional*):
The data format of the output image. If unset, the same format as the input image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
image_height, image_width = get_image_size(image, input_data_format)
target_height, target_width = size["height"], size["width"]
padding_top = 0
padding_left = 0
padding_bottom = target_height - image_height
padding_right = target_width - image_width
padded_image = pad(
image,
padding=((padding_top, padding_bottom), (padding_left, padding_right)),
mode=mode,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image
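A small arithmetic sketch of why resizing before padding keeps both pad amounts non-negative, which is the scaling bug mentioned in the commit message; the image size used here is an arbitrary example against the Fuyu default target of 1080x1920:

target_h, target_w = 1080, 1920
image_h, image_w = 2160, 2880  # larger than the target in both dimensions

scale = min(target_h / image_h, target_w / image_w)         # 0.5 here
new_h, new_w = int(image_h * scale), int(image_w * scale)   # 1080, 1440
pad_bottom, pad_right = target_h - new_h, target_w - new_w  # 0, 480 -> never negative
assert pad_bottom >= 0 and pad_right >= 0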
def preprocess(
self,
images,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
do_pad: Optional[bool] = None,
padding_value: Optional[float] = None,
padding_mode: Optional[str] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[float] = None,
image_std: Optional[float] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
patch_size: Optional[Dict[str, int]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_tensors: Optional[TensorType] = None,
):
"""
Utility function to preprocess the images and extract necessary information about original formats.
Args:
images (`ImageInput`):
Images to preprocess. Expects a single image, a list of images or a list of lists of images. Pixel
values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image to `size`.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to `size`.
padding_value (`float`, *optional*, defaults to `self.padding_value`):
The value to pad the image with.
padding_mode (`str`, *optional*, defaults to `self.padding_mode`):
The padding mode to use when padding the image.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float`, *optional*, defaults to `self.image_mean`):
The mean to use when normalizing the image.
image_std (`float`, *optional*, defaults to `self.image_std`):
The standard deviation to use when normalizing the image.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
The factor to use when rescaling the image.
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format of the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
resample = resample if resample is not None else self.resample
do_pad = do_pad if do_pad is not None else self.do_pad
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
padding_value = padding_value if padding_value is not None else self.padding_value
padding_mode = padding_mode if padding_mode is not None else self.padding_mode
patch_size = patch_size if patch_size is not None else self.patch_size
if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
raise ValueError("Multiple images for a single sample are not yet supported.")
batch_images = make_list_of_list_of_images(images)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]
if is_scaled_image(batch_images[0][0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(batch_images[0][0])
original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
if do_resize:
batch_images = [
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
for images in batch_images
]
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
# scale_h is the same as scale_w
image_scale_factors = [
[resized_size[0] / original_size[0]]
for original_size, resized_size in zip(original_image_sizes, image_sizes)
]
if do_pad:
batch_images = [
[
self.pad_image(
image,
size=size,
mode=padding_mode,
constant_values=padding_value,
input_data_format=input_data_format,
)
for image in images
]
for images in batch_images
]
if do_rescale:
batch_images = [
[self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
for images in batch_images
]
if do_normalize:
batch_images = [
[
self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
for images in batch_images
]
if data_format is not None:
batch_images = [
[to_channel_dimension_format(image, data_format, input_data_format) for image in images]
for images in batch_images
]
data = {
"images": batch_images,
"image_unpadded_heights": image_unpadded_heights,
"image_unpadded_widths": image_unpadded_widths,
"image_scale_factors": image_scale_factors,
}
return FuyuBatchFeature(data=data, tensor_type=return_tensors)
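A rough sketch of calling the reworked preprocess on two differently sized images. The sizes are arbitrary, but the shapes in the comments follow from the resize-to-fit-(1080, 1920) and pad logic above; this assumes a transformers build that includes this change.

import numpy as np
from transformers import FuyuImageProcessor

image_processor = FuyuImageProcessor()
images = [
    np.random.randint(0, 256, (540, 960, 3), dtype=np.uint8),    # already fits: kept at 540x960, then padded
    np.random.randint(0, 256, (2160, 2880, 3), dtype=np.uint8),  # scaled by 0.5 to 1080x1440, then padded
]
outputs = image_processor.preprocess(images)

print(outputs["images"][0][0].shape)      # (3, 1080, 1920) after (optional) resize + pad
print(outputs["image_unpadded_heights"])  # post-resize, pre-padding heights: 540 and 1080
print(outputs["image_unpadded_widths"])   # post-resize, pre-padding widths: 960 and 1440
print(outputs["image_scale_factors"])     # per-image resize factors: 1.0 and 0.5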
def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int:
"""
Calculate number of patches required to encode an image.
Args:
image_height (`int`):
Height of the image.
image_width (`int`):
Width of the image.
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
"""
patch_size = patch_size if patch_size is not None else self.patch_size
patch_height, patch_width = patch_size["height"], patch_size["width"]
if image_height % patch_height != 0:
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
if image_width % patch_width != 0:
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
num_patches_per_dim_h = image_height // patch_height
num_patches_per_dim_w = image_width // patch_width
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
return num_patches
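As a quick worked example of the patch count, using the default sizes defined above (a padded 1080x1920 image and 30x30 patches):

patch_height, patch_width = 30, 30
image_height, image_width = 1080, 1920
num_patches = (image_height // patch_height) * (image_width // patch_width)
print(num_patches)  # 36 * 64 = 2304 patches for a full-size image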
def patchify_image(self, image: "torch.Tensor", patch_size: Optional[Dict[str, int]] = None) -> "torch.Tensor":
"""
Convert an image into a tensor of patches.
Args:
image (`torch.Tensor`):
Image to convert. Shape: [batch, channels, height, width]
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
"""
requires_backends(self, ["torch"])
patch_size = patch_size if patch_size is not None else self.patch_size
patch_height, patch_width = patch_size["height"], patch_size["width"]
# TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
# torch implementation is faster but does not handle non-squares
batch_size, channels, _, _ = image.shape
unfolded_along_height = image.unfold(2, patch_height, patch_height)
patches = unfolded_along_height.unfold(3, patch_width, patch_width)
patches = patches.contiguous()
patches = patches.view(batch_size, channels, -1, patch_height, patch_width)
patches = patches.permute(0, 2, 3, 4, 1)
patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width)
return patches
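A standalone check, not part of the diff, that the unfold-based patchify above matches naive slicing; the tensor sizes are arbitrary:

import torch

def patchify(image, patch_h, patch_w):
    batch_size, channels, _, _ = image.shape
    patches = image.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
    patches = patches.contiguous().view(batch_size, channels, -1, patch_h, patch_w)
    return patches.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, channels * patch_h * patch_w)

image = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)
patches = patchify(image, 2, 3)
print(patches.shape)  # torch.Size([2, 4, 18]) -> (4//2) * (6//3) patches of 2*3*3 values each

# Naive reference: the first patch, flattened row-major with channels last.
first_patch = image[0, :, 0:2, 0:3].permute(1, 2, 0).reshape(-1)
assert torch.equal(patches[0, 0], first_patch)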
def preprocess_with_tokenizer_info(
self,
image_input: "torch.Tensor",
image_present: "torch.Tensor",
image_unpadded_h: "torch.Tensor",
image_unpadded_w: "torch.Tensor",
image_placeholder_id: int,
image_newline_id: int,
variable_sized: bool,
patch_size: Optional[Dict[str, int]] = None,
) -> FuyuBatchFeature:
"""Process images for model input. In particular, variable-sized images are handled here.
Args:
image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]):
Tensor of images padded to model input size.
image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]):
Tensor of 1s and 0s indicating whether an image is present.
image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]):
Tensor of unpadded image heights.
image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]):
Tensor of unpadded image widths.
image_placeholder_id (int):
The id of the image placeholder token. Comes from an associated tokenizer.
image_newline_id (int):
The id of the image newline token. Comes from an associated tokenizer.
variable_sized (bool):
Whether to process images as variable-sized.
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Size of the patches.
"""
requires_backends(self, ["torch"])
patch_size = patch_size if patch_size is not None else self.patch_size
patch_height, patch_width = patch_size["height"], patch_size["width"]
# Only images that are present.
images: List[List[torch.Tensor]] = []
batch_image_patches: List[List[torch.Tensor]] = []
# Image input ids for every subsequence, including ones with no image present.
batch_image_input_ids: List[List[torch.Tensor]] = []
for batch_index in range(image_input.shape[0]):
image_input_ids = []
image_patches = []
for subseq_index in range(image_input.shape[1]):
if image_present[batch_index, subseq_index]:
image = image_input[batch_index, subseq_index]
image_height, image_width = image.shape[1], image.shape[2]
if variable_sized:
# The min() is required here due to floating point issues:
# math.ceil(torch.tensor(300).cuda() / 30) == 11
new_h = min(
image_height,
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
)
new_w = min(
image_width,
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
)
image = image[:, :new_h, :new_w]
image_height, image_width = new_h, new_w
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
tensor_of_image_ids = torch.full(
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
)
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
assert num_patches == patches.shape[0]
if variable_sized:
# Now terminate each line with |NEWLINE|.
tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
newline_ids = torch.full(
[tensor_of_image_ids.shape[0], 1],
image_newline_id,
dtype=torch.int32,
device=image_input.device,
)
tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
images.append([image])
image_input_ids.append(tensor_of_image_ids)
image_patches.append(patches)
else:
image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
batch_image_input_ids.append(image_input_ids)
batch_image_patches.append(image_patches)
# Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
# the stream.
image_patch_indices_per_batch: List[List[torch.Tensor]] = []
image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
for sample_image_input_ids in batch_image_input_ids:
index_offset = 0
per_batch_indices = []
per_subsequence_indices = []
for subseq_image_input_ids in sample_image_input_ids:
# Indices of image patches.
patches_mask = subseq_image_input_ids == image_placeholder_id
num_patches = torch.count_nonzero(patches_mask)
indices = torch.arange(
num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device
)
# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1)
indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1)
patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0]
indices_in_stream_per_batch[patches_inds] = indices + index_offset
indices_in_stream_per_subsequence[patches_inds] = indices
per_batch_indices.append(indices_in_stream_per_batch)
per_subsequence_indices.append(indices_in_stream_per_subsequence)
index_offset += num_patches
image_patch_indices_per_batch.append(per_batch_indices)
image_patch_indices_per_subsequence.append(per_subsequence_indices)
return FuyuBatchFeature(
data={
"images": images,
"image_input_ids": batch_image_input_ids,
"image_patches": batch_image_patches,
"image_patch_indices_per_batch": image_patch_indices_per_batch,
"image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
}
)
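To illustrate the variable-sized token layout built above, here is a tiny standalone sketch; the ids are made-up stand-ins for the tokenizer's placeholder and newline ids:

import torch

image_placeholder_id, image_newline_id = 1, 2  # stand-in ids; real ones come from the tokenizer
num_rows, num_cols = 2, 3  # an image cropped to a 2x3 patch grid

tensor_of_image_ids = torch.full([num_rows * num_cols], image_placeholder_id, dtype=torch.int32)
tensor_of_image_ids = tensor_of_image_ids.reshape(-1, num_cols)
newline_ids = torch.full([num_rows, 1], image_newline_id, dtype=torch.int32)
tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1).reshape(-1)
print(tensor_of_image_ids.tolist())  # [1, 1, 1, 2, 1, 1, 1, 2]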
...@@ -257,8 +257,10 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
for patch in image_patches
]
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
...
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for Fuyu
"""
import re
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
from ...utils import TensorType, is_torch_available, logging, requires_backends
if is_torch_available():
from .image_processing_fuyu import FuyuBatchFeature
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
TEXT_REPR_BBOX_OPEN = "<box>"
TEXT_REPR_BBOX_CLOSE = "</box>"
TEXT_REPR_POINT_OPEN = "<point>"
TEXT_REPR_POINT_CLOSE = "</point>"
TOKEN_BBOX_OPEN_STRING = "<0x00>"  # <bbox>
TOKEN_BBOX_CLOSE_STRING = "<0x01>"  # </bbox>
TOKEN_POINT_OPEN_STRING = "<0x02>"  # <point>
TOKEN_POINT_CLOSE_STRING = "<0x03>"  # </point>
BEGINNING_OF_ANSWER_STRING = "<0x04>"  # <boa>
...@@ -87,18 +92,16 @@ def construct_full_unpacked_stream(
all_bi_stream = []
for batch_index in range(batch_size):
all_si_stream = []
# First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence
# and append to lists. We use lists rather than tensors because each subsequence is variable-sized.
# TODO Remove this logic in a subsequent release since subsequences are not supported.
image_adjustment = image_tokens[batch_index][0]
subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0)
num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0]
all_si_stream.append(subsequence_stream[:num_real_tokens])
# Combine all subsequences for this batch entry. Still using a list because each batch entry is variable-sized.
all_bi_stream.append(torch.cat(all_si_stream, dim=0))
return all_bi_stream
...@@ -137,7 +140,7 @@ def _segment_prompt_into_text_token_conversions(prompt: str) -> List:
return prompt_text_list
def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> List[int]:
"""
This function transforms the prompt in the following fashion:
- <box> <point> and </box> </point> to their respective token mappings
...@@ -161,7 +164,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
for elem in prompt_text_list:
if elem[1]:
# This is a location, we need to tokenize it
within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer)
# Surround the text with the open and close tags
transformed_prompt_tokens.extend(within_tag_tokenized)
else:
...@@ -169,7 +172,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
return transformed_prompt_tokens
def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> List[int]:
"""
Given a bounding box of the fashion <box>1, 2, 3, 4</box> | <point>1, 2</point> This function is responsible for
converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas.
...@@ -188,16 +191,14 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
num_ints = [float(num.strip()) for num in num_int_strs]
# scale to transformed image size
if len(num_ints) == 2:
num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor)
elif len(num_ints) == 4:
num_ints_translated = scale_bbox_to_transformed_image(
top=num_ints[0],
left=num_ints[1],
bottom=num_ints[2],
right=num_ints[3],
scale_factor=scale_factor,
)
else:
raise ValueError(f"Invalid number of ints: {len(num_ints)}")
...@@ -209,7 +210,7 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
def _tokenize_prompts_with_image_and_batch(
tokenizer,
prompts: List[List[str]],
scale_factors: Optional[List[List["torch.Tensor"]]],
max_tokens_to_generate: int,
max_position_embeddings: int,
add_BOS: bool,  # Same issue with types as above
...@@ -223,13 +224,13 @@ def _tokenize_prompts_with_image_and_batch(
"""
# If not tool use, transform the coordinates while tokenizing
if scale_factors is not None:
transformed_prompt_tokens = []
for prompt_seq, scale_factor_seq in zip(prompts, scale_factors):
transformed_prompt_tokens.append(
[
_transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer)
for prompt, scale_factor in zip(prompt_seq, scale_factor_seq)
]
)
else:
...@@ -260,7 +261,7 @@ def _tokenize_prompts_with_image_and_batch(
# Number of tokens in each sample of the batch.
samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings)
if max_prompt_len + max_tokens_to_generate > max_position_embeddings:
logger.warning(
f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
)
...@@ -279,86 +280,30 @@ def _tokenize_prompts_with_image_and_batch(
return prompts_tokens_tensor, prompts_length_tensor
# Simplified assuming self.crop_top = self.padding_top = 0
def original_to_transformed_h_coords(original_coords, scale_h):
return np.round(original_coords * scale_h).astype(np.int32)
# Simplified assuming self.crop_left = self.padding_left = 0
def original_to_transformed_w_coords(original_coords, scale_w):
return np.round(original_coords * scale_w).astype(np.int32)
def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> List[int]:
x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0]
y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0]
return [x_scaled, y_scaled]
def scale_bbox_to_transformed_image(
top: float, left: float, bottom: float, right: float, scale_factor: float
) -> List[int]:
top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0]
left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0]
bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0]
right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0]
return [top_scaled, left_scaled, bottom_scaled, right_scaled]
class FuyuProcessor(ProcessorMixin):
...@@ -384,42 +329,148 @@ class FuyuProcessor(ProcessorMixin):
self.tokenizer = tokenizer
self.max_tokens_to_generate = 10
self.max_position_embeddings = 16384  # TODO Can't derive this from model files: where to set it?
self.pad_token_id = 0
self.dummy_image_index = -1
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs)
batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []}
for entry in model_inputs:
for key, tensor in entry.items():
if key == "input_ids":
num_padding_tokens = max_length_input_ids - tensor.shape[1]
padded_input_ids = torch.cat(
[
torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long),
tensor,
],
dim=1,
)
batched_inputs[key].append(padded_input_ids)
attention_mask = torch.cat(
[torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)],
dim=1,
)
batched_inputs["attention_mask"].append(attention_mask)
elif key == "image_patches":
# For image_patches, we don't pad but just append them to the list.
batched_inputs[key].append(tensor)
else:  # for image_patches_indices
num_padding_indices = max_length_image_patch_indices - tensor.shape[1]
padded_indices = torch.cat(
[
torch.full(
(tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long
),
tensor,
],
dim=1,
)
batched_inputs[key].append(padded_indices)
batched_keys = ["input_ids", "image_patches_indices"]
if return_attention_mask:
batched_keys.append("attention_mask")
for key in batched_keys:
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
return batched_inputs
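A standalone illustration of the left-padding that `_left_pad_inputs_with_attention_mask` performs (toy tensors; the pad id of 0 and the zeroed left side of the attention mask mirror the logic above):

import torch

pad_token_id = 0
input_ids = [torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])]
max_len = max(t.shape[1] for t in input_ids)

padded, masks = [], []
for t in input_ids:
    num_pad = max_len - t.shape[1]
    padded.append(torch.cat([torch.full((1, num_pad), pad_token_id, dtype=torch.long), t], dim=1))
    masks.append(torch.cat([torch.zeros(1, num_pad, dtype=torch.long), torch.ones_like(t)], dim=1))

print(torch.cat(padded))  # tensor([[5, 6, 7], [0, 8, 9]])
print(torch.cat(masks))   # tensor([[1, 1, 1], [0, 1, 1]])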
def get_sample_encoding(
self,
prompts,
scale_factors,
image_unpadded_heights,
image_unpadded_widths,
image_placeholder_id,
image_newline_id,
tensor_batch_images,
):
image_present = torch.ones(1, 1, 1)
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
image_input=tensor_batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
variable_sized=True,
)
# FIXME max_tokens_to_generate is embedded into this processor's call.
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
tokenizer=self.tokenizer,
prompts=prompts,
scale_factors=scale_factors,
max_tokens_to_generate=self.max_tokens_to_generate,
max_position_embeddings=self.max_position_embeddings,
add_BOS=True,
add_beginning_of_answer_token=True,
)
image_padded_unpacked_tokens = construct_full_unpacked_stream(
num_real_text_tokens=prompts_length,
input_stream=prompt_tokens,
image_tokens=model_image_input["image_input_ids"],
batch_size=1,
num_sub_sequences=self.subsequence_length,
)
# Construct inputs for image patch indices.
unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
num_real_text_tokens=prompts_length,
input_stream=torch.full_like(prompt_tokens, -1),
image_tokens=model_image_input["image_patch_indices_per_batch"],
batch_size=1,
num_sub_sequences=self.subsequence_length,
)
max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0]))
# Use same packing logic for the image patch indices.
image_patch_input_indices = full_unpacked_stream_to_tensor(
all_bi_tokens_to_place=[tokens_to_place],
full_unpacked_stream=unpacked_image_patch_indices_per_batch,
fill_value=-1,
batch_size=1,
new_seq_len=max_seq_len_batch,
offset=0,
)
image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]])
batch_encoding = {
"input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
"image_patches": image_patches_tensor,
"image_patches_indices": image_patch_input_indices,
}
return batch_encoding
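For intuition about the shapes `get_sample_encoding` works with: with the default 30x30 patch size, a variable-sized image contributes one patch token per tile plus one |NEWLINE| token per row of tiles. A rough back-of-the-envelope, assuming the bus image used in the tests resizes to about 420x660 (consistent with the 14 rows of 22 patch indices in the expected tensors further down):

import math

patch_h = patch_w = 30            # FuyuImageProcessor default patch size
resized_h, resized_w = 420, 660   # assumed unpadded size after resizing

cols = math.ceil(resized_w / patch_w)  # 22 patch tokens per image row
rows = math.ceil(resized_h / patch_h)  # 14 rows
print(rows * cols, rows)               # 308 patch tokens, 14 |NEWLINE| tokens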
def __call__(
self,
text=None,
images=None,
add_special_tokens: bool = True,
return_attention_mask: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> "FuyuBatchFeature":
""" """
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text` Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information. of the above two methods for more information.
...@@ -433,130 +484,211 @@ class FuyuProcessor(ProcessorMixin): ...@@ -433,130 +484,211 @@ class FuyuProcessor(ProcessorMixin):
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width. number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns: Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields: [`FuyuBatchFeature`]: A [`FuyuBatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when - **image_patches** -- List of Tensor of image patches. Returned when `images` is not `None`.
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not - **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model.
`None`). - **attention_mask** -- List of indices specifying which tokens should be attended to by the model when
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. `return_attention_mask=True`.
""" """
requires_backends(self, ["torch"])
# --- Check input validity ---
if not return_attention_mask:
raise ValueError("`return_attention_mask=False` is not supported for this model.")
if text is None and images is None: if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.") raise ValueError("You have to specify either text or images. Both cannot be None.")
if text is not None and images is None:
logger.warning("You are processing a text with no associated image. Make sure it is intended.")
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding
if text is None and images is not None:
logger.warning("You are processing an image with no associated text. Make sure it is intended.")
prompts = [[""]]
if text is not None and images is not None: if text is not None and images is not None:
if isinstance(text, str): if isinstance(text, str):
prompts = [[text]] prompts = [[text]]
elif isinstance(text, list): elif isinstance(text, list):
prompts = [[text_seq] for text_seq in text] prompts = [[text_seq] for text_seq in text]
batch_images = []
if isinstance(images, PIL.Image.Image): # --- Preprocess images using self.image_processor ---
images = [images]
if isinstance(images, list): # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
batch_images, image_unpadded_heights, image_unpadded_widths = self._process_images(images) image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
# image_unpadded_heights = image_unpadded_heights.unsqueeze(0) batch_images = image_encoding["images"]
# image_unpadded_widths = image_unpadded_widths.unsqueeze(0) image_unpadded_heights = image_encoding["image_unpadded_heights"]
else: image_unpadded_widths = image_encoding["image_unpadded_widths"]
raise ValueError("images must be a list of ndarrays or PIL Images to be processed.") scale_factors = image_encoding["image_scale_factors"]
self.subsequence_length = 1 # Each batch contains only one sequence.
# Note: the original adept code has a handling of image_unpadded_h and w, but it doesn't seem to hold self.batch_size = len(batch_images)
# when there are several different size subsequences per batch. The current implementation reflects
# that limitation and should be documented. # --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
#
self.subsequence_length = 1 # Each batch contains only one sequence. image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
self.batch_size = len(batch_images) image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
# FIXME max_tokens_to_generate is embedded into this processor's call. tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
tokenizer=self.tokenizer, # --- Use self.image_processor again to obtain the full token ids and batch inputs ---
prompts=prompts, all_encodings = []
transformed_images=batch_images,
max_tokens_to_generate=self.max_tokens_to_generate, for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip(
max_position_embeddings=self.max_position_embeddings, prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images
add_BOS=True, ):
add_beginning_of_answer_token=True, sample_encoding = self.get_sample_encoding(
) prompts=[prompt],
# same so far scale_factors=[scale_factor],
image_unpadded_heights=torch.tensor([image_unpadded_height]),
# This is 1 if there is an image per subsequence, else 0. [batch, 1, presence] image_unpadded_widths=torch.tensor([image_unpadded_width]),
# the remainder of current image processing logic assumes subsequence_size = 1.
# Here it is OK as the model cannot handle > 1 subsequences
# the image could be absent however and image presence should be inferred from user batch input
# hence this code assumes the images are present. Use an assert?
image_present = torch.ones(self.batch_size, 1, 1)
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
model_image_input = self.image_processor.process_images_for_model_input(
image_input=tensor_batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_patch_dim_h=30,
image_patch_dim_w=30,
image_placeholder_id=image_placeholder_id, image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id, image_newline_id=image_newline_id,
variable_sized=True, tensor_batch_images=tensor_batch_image.unsqueeze(0),
) )
all_encodings.append(sample_encoding)
batch_encoding = self._left_pad_inputs_with_attention_mask(
model_inputs=all_encodings, return_attention_mask=return_attention_mask
)
return FuyuBatchFeature(data=batch_encoding)
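For reference, a typical end-to-end call of the reworked `__call__` (checkpoint and prompt follow the integration tests further down; the printed shapes are indicative only):

import io
import requests
from PIL import Image
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(url).content))

inputs = processor(text="Generate a coco-style caption.\n", images=image)
print(inputs["input_ids"].shape)              # (batch_size, sequence_length)
print(inputs["image_patches_indices"].shape)  # same layout, -1 where no patch embedding goes
print(len(inputs["image_patches"]))           # one tensor of flattened patches per image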
image_padded_unpacked_tokens = construct_full_unpacked_stream( def post_process_box_coordinates(self, outputs, target_sizes=None):
num_real_text_tokens=prompts_length, """
input_stream=prompt_tokens, Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
image_tokens=model_image_input["image_input_ids"], Coordinates will be returned in "box" format, with the following pattern:
batch_size=self.batch_size, `<box>top, left, bottom, right</box>`
num_sub_sequences=self.subsequence_length,
) Point coordinates are not supported yet.
# Construct inputs for image patch indices.
unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
num_real_text_tokens=prompts_length,
input_stream=torch.full_like(prompt_tokens, -1),
image_tokens=model_image_input["image_patch_indices_per_batch"],
batch_size=self.batch_size,
num_sub_sequences=self.subsequence_length,
)
max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
all_bi_tokens_to_place = []
for bi in range(self.batch_size):
tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[bi].shape[0]))
all_bi_tokens_to_place.append(tokens_to_place)
# Use same packing logic for the image patch indices.
image_patch_input_indices = full_unpacked_stream_to_tensor(
all_bi_tokens_to_place=all_bi_tokens_to_place,
full_unpacked_stream=unpacked_image_patch_indices_per_batch,
fill_value=-1,
batch_size=self.batch_size,
new_seq_len=max_seq_len_batch,
offset=0,
)
image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]).unsqueeze(1) Args:
return { outputs ([`GenerateOutput`]):
"input_ids": image_padded_unpacked_tokens[0].unsqueeze(0), Raw outputs from `generate`.
"image_patches": image_patches_tensor[0][0].unsqueeze(0), target_sizes (`torch.Tensor`, *optional*):
"image_patches_indices": image_patch_input_indices, Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
} the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left
to None, coordinates will not be rescaled.
Returns:
`GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with
boxed and possibly rescaled coordinates.
"""
def scale_factor_to_fit(original_size, target_size=None):
height, width = original_size
if target_size is None:
max_height = self.image_processor.size["height"]
max_width = self.image_processor.size["width"]
else:
max_height, max_width = target_size
if width <= max_width and height <= max_height:
return 1.0
return min(max_height / height, max_width / width)
def find_delimiters_pair(tokens, start_token, end_token):
start_id = self.tokenizer.convert_tokens_to_ids(start_token)
end_id = self.tokenizer.convert_tokens_to_ids(end_token)
starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0]
ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0]
if torch.any(starting_positions) and torch.any(ending_positions):
return (starting_positions[0], ending_positions[0])
return (None, None)
def tokens_to_boxes(tokens, original_size):
while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != (
None,
None,
):
start, end = pair
if end != start + 5:
continue
# Retrieve transformed coordinates from tokens
coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
# Scale back to original image size and multiply by 2
scale = scale_factor_to_fit(original_size)
top, left, bottom, right = [2 * int(float(c) / scale) for c in coords]
# Replace the IDs so they get detokenized right
replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}"
replacement = self.tokenizer.tokenize(replacement)[1:]
replacement = self.tokenizer.convert_tokens_to_ids(replacement)
replacement = torch.tensor(replacement).to(tokens)
tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
return tokens
def tokens_to_points(tokens, original_size):
while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != (
None,
None,
):
start, end = pair
if end != start + 3:
continue
# Retrieve transformed coordinates from tokens
coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
# Scale back to original image size and multiply by 2
scale = scale_factor_to_fit(original_size)
x, y = [2 * int(float(c) / scale) for c in coords]
# Replace the IDs so they get detokenized right
replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}"
replacement = self.tokenizer.tokenize(replacement)[1:]
replacement = self.tokenizer.convert_tokens_to_ids(replacement)
replacement = torch.tensor(replacement).to(tokens)
tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
return tokens
if target_sizes is None:
target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs)
elif target_sizes.shape[1] != 2:
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
if len(outputs) != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as output sequences")
results = []
for seq, size in zip(outputs, target_sizes):
seq = tokens_to_boxes(seq, size)
seq = tokens_to_points(seq, size)
results.append(seq)
return results
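And a hedged sketch of where `post_process_box_coordinates` slots in after generation (the OCR-style prompt mirrors the bounding-box integration test; in practice you would also move the model and inputs to an accelerator):

import io
import requests
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")

url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(url).content))
prompt = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\nWilliams"

inputs = processor(text=prompt, images=image)
generated = model.generate(**inputs, max_new_tokens=10)

# Rescale any <box>/<point> coordinates back to the original image's (height, width) before decoding.
target_sizes = torch.tensor([image.size[::-1]])
generated = processor.post_process_box_coordinates(generated, target_sizes=target_sizes)
print(processor.batch_decode(generated, skip_special_tokens=True))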
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
""" """
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information. refer to the docstring of this method for more information.
""" """
return self.tokenizer.batch_decode(*args, **kwargs) return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs): def decode(self, *args, **kwargs):
""" """
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
...@@ -24,7 +24,8 @@ if is_vision_available(): ...@@ -24,7 +24,8 @@ if is_vision_available():
@require_torchvision @require_torchvision
class TestFuyuImageProcessor(unittest.TestCase): class TestFuyuImageProcessor(unittest.TestCase):
def setUp(self): def setUp(self):
self.processor = FuyuImageProcessor(target_height=160, target_width=320, padding_value=1.0) self.size = {"height": 160, "width": 320}
self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0)
self.batch_size = 3 self.batch_size = 3
self.channels = 3 self.channels = 3
self.height = 300 self.height = 300
...@@ -38,29 +39,25 @@ class TestFuyuImageProcessor(unittest.TestCase): ...@@ -38,29 +39,25 @@ class TestFuyuImageProcessor(unittest.TestCase):
self.sample_image_pil = Image.fromarray(self.sample_image) self.sample_image_pil = Image.fromarray(self.sample_image)
def test_patches(self): def test_patches(self):
expected_num_patches = self.processor.get_num_patches( expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width)
img_h=self.height, img_w=self.width, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
)
patches_final = self.processor.patchify_image( patches_final = self.processor.patchify_image(image=self.image_input)
image=self.image_input, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
)
assert ( assert (
patches_final.shape[1] == expected_num_patches patches_final.shape[1] == expected_num_patches
), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}." ), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}."
def test_scale_to_target_aspect_ratio(self): def test_scale_to_target_aspect_ratio(self):
# (h:450, w:210) fitting (160, 320) -> (160, 210*160/450) # (h:450, w:210) fitting (160, 320) -> (160, 210*160/450)
scaled_image = self.processor._scale_to_target_aspect_ratio(self.sample_image) scaled_image = self.processor.resize(self.sample_image, size=self.size)
self.assertEqual(scaled_image.shape[0], 160) self.assertEqual(scaled_image.shape[0], 160)
self.assertEqual(scaled_image.shape[1], 74) self.assertEqual(scaled_image.shape[1], 74)
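The arithmetic behind that expectation, as a quick sanity check (the processor's resize may round slightly differently, but the 160 x 74 result holds):

# Aspect-ratio-preserving fit: a (450, 210) image into a (160, 320) box is
# constrained by height, so both sides shrink by 160/450.
h, w = 450, 210
target_h, target_w = 160, 320
new_h = target_h
new_w = w * target_h // h
print(new_h, new_w)  # 160 74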
def test_apply_transformation_numpy(self): def test_apply_transformation_numpy(self):
transformed_image = self.processor.apply_transformation(self.sample_image) transformed_image = self.processor.preprocess(self.sample_image).images[0][0]
self.assertEqual(transformed_image.shape[0], 160) self.assertEqual(transformed_image.shape[1], 160)
self.assertEqual(transformed_image.shape[1], 320) self.assertEqual(transformed_image.shape[2], 320)
def test_apply_transformation_pil(self): def test_apply_transformation_pil(self):
transformed_image = self.processor.apply_transformation(self.sample_image_pil) transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0]
self.assertEqual(transformed_image.shape[0], 160) self.assertEqual(transformed_image.shape[1], 160)
self.assertEqual(transformed_image.shape[1], 320) self.assertEqual(transformed_image.shape[2], 320)
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
import requests import requests
from transformers import AutoTokenizer, FuyuConfig, is_torch_available, is_vision_available from transformers import FuyuConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from ...test_modeling_common import ids_tensor, random_attention_mask from ...test_modeling_common import ids_tensor, random_attention_mask
...@@ -14,7 +14,7 @@ if is_vision_available(): ...@@ -14,7 +14,7 @@ if is_vision_available():
if is_torch_available() and is_vision_available(): if is_torch_available() and is_vision_available():
from transformers import FuyuImageProcessor, FuyuProcessor from transformers import FuyuProcessor
if is_torch_available(): if is_torch_available():
...@@ -267,11 +267,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) ...@@ -267,11 +267,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else () all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
def setUp(self): def setUp(self):
self.pretrained_model_name = "huggingface/new_model_release_weights" self.pretrained_model_name = "adept/fuyu-8b"
tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name) self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name)
image_processor = FuyuImageProcessor()
self.processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name) self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
self.bus_image_url = ( self.bus_image_url = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
...@@ -280,9 +277,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) ...@@ -280,9 +277,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
@slow @slow
def test_model_8b_chat_greedy_generation_bus_captioning(self): def test_model_8b_chat_greedy_generation_bus_captioning(self):
EXPECTED_TEXT_COMPLETION = """A bus parked on the side of a road.|ENDOFTEXT|""" EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|"""
text_prompt_coco_captioning = "Generate a coco-style caption.\n" text_prompt_coco_captioning = "Generate a coco-style caption.\n"
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil) model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10) generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens) text = self.processor.tokenizer.batch_decode(generated_tokens)
...@@ -297,7 +293,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) ...@@ -297,7 +293,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
""" """
@slow @slow
@require_torch_gpu @require_torch_accelerator
def test_model_8b_chat_greedy_generation_bus_color(self): def test_model_8b_chat_greedy_generation_bus_color(self):
EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|" EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|"
text_prompt_bus_color = "What color is the bus?\n" text_prompt_bus_color = "What color is the bus?\n"
...@@ -314,7 +310,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) ...@@ -314,7 +310,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence) self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence)
@slow @slow
@require_torch_gpu @require_torch_accelerator
def test_model_8b_chat_greedy_generation_chart_vqa(self): def test_model_8b_chat_greedy_generation_chart_vqa(self):
# fmt: off # fmt: off
EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",] EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",]
...@@ -340,7 +336,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) ...@@ -340,7 +336,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
self.assertEqual(expected_text_completion, clean_sequence) self.assertEqual(expected_text_completion, clean_sequence)
@slow @slow
@require_torch_gpu @require_torch_accelerator
def test_model_8b_chat_greedy_generation_bounding_box(self): def test_model_8b_chat_greedy_generation_bounding_box(self):
EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|" EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|"
text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231 text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231
......
...@@ -26,16 +26,14 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here ...@@ -26,16 +26,14 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
""" """ """ """
def setUp(self): def setUp(self):
pretrained_model_name = "huggingface/pre_release_model" pretrained_model_name = "adept/fuyu-8b"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
image_processor = FuyuImageProcessor() self.image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer) self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
text_prompt = "Generate a coco-style caption.\\n" self.text_prompt = "Generate a coco-style caption.\\n"
bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content)) self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
self.one_image_bus_model_inputs = processor(text=text_prompt, images=bus_image_pil)
def test_fuyu_processing(self): def test_fuyu_processing(self):
""" """
...@@ -44,11 +42,119 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here ...@@ -44,11 +42,119 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
# fmt: off # fmt: off
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64) EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64) EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 
71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil)
# fmt: on
torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS)
torch.testing.assert_close(one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
def test_fuyu_processing_no_image(self):
"""
Test to check processor works with just text input
"""
processor_outputs = self.processor(text=self.text_prompt)
tokenizer_outputs = self.tokenizer(self.text_prompt)
self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"])
def test_fuyu_processing_no_text(self):
"""
Test to check processor works with just image input
"""
# fmt: off
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26,
27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66,
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133,
134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160,
161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174,
175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200,
201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214,
215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227,
228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
-1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267,
268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281,
282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]).to(torch.int64)
# fmt: on
processor_outputs = self.processor(images=self.bus_image_pil)
self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all())
def test_fuyu_processing_multiple_image_sample(self):
"""
Test to check processor works with multiple image inputs for a single text input
"""
# fmt: off
SINGLE_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
SINGLE_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
SINGLE_RESIZED_IMAGE_PATCH_INPUTS = torch.Tensor([[ 0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9, 10, 11, -1, 12, 13, 14, -1, 15, 16, 17, -1, 18, 19, 20, -1, 21, 22, 23, -1, 24, 25, 26, -1, 27, 28, 29, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])
SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[ 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122]])
# fmt: on # fmt: on
torch.testing.assert_close(
self.one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS # Batch of two images - equally sized
images = [self.bus_image_pil, self.bus_image_pil]
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
self.assertTrue(
(
processor_outputs["image_patches_indices"]
== torch.cat([SINGLE_IMAGE_PATCH_INPUTS, SINGLE_IMAGE_PATCH_INPUTS], dim=0)
).all()
)
self.assertTrue(
(
processor_outputs["input_ids"]
== torch.cat([SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, SINGLE_PADDED_UNPACKED_TOKEN_INPUTS], dim=0)
).all()
) )
torch.testing.assert_close(self.one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
# Processes single images with different sizes as expected
images = [self.bus_image_pil]
processor_outputs = self.processor(text=self.text_prompt, images=images)
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all())
self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all())
images = [self.bus_image_pil.resize((64, 300))]
processor_outputs = self.processor(text=self.text_prompt, images=images)
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all())
self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all())
# Batch of two images - different sizes. Left-pads the smaller image inputs
images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))]
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1]
padded_single_resized_image_patch = torch.cat(
[torch.ones([1, padding_len_patch]) * -1, SINGLE_RESIZED_IMAGE_PATCH_INPUTS], dim=1
)
expected_image_patch_inputs = torch.cat([SINGLE_IMAGE_PATCH_INPUTS, padded_single_resized_image_patch], dim=0)
padding_len_token = (
SINGLE_PADDED_UNPACKED_TOKEN_INPUTS.shape[1] - SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS.shape[1]
)
padded_single_resized_padded_unpacked_token_inputs = torch.cat(
[torch.zeros([1, padding_len_token]), SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS], dim=1
)
expected_padded_unpacked_token_inputs = torch.cat(
[SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, padded_single_resized_padded_unpacked_token_inputs], dim=0
)
self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all())
self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all())
@require_torch @require_torch
...@@ -97,7 +203,6 @@ class TestProcessImagesForModelInput(unittest.TestCase): ...@@ -97,7 +203,6 @@ class TestProcessImagesForModelInput(unittest.TestCase):
""" """
Adding a mix of present and absent images. Adding a mix of present and absent images.
""" """
self.image_processor = FuyuImageProcessor()
self.image_input = torch.randn([1, 1, 3, 64, 64]) self.image_input = torch.randn([1, 1, 3, 64, 64])
self.image_present = torch.tensor([[1]]) self.image_present = torch.tensor([[1]])
...@@ -108,19 +213,19 @@ class TestProcessImagesForModelInput(unittest.TestCase): ...@@ -108,19 +213,19 @@ class TestProcessImagesForModelInput(unittest.TestCase):
self.image_placeholder_id = 999 self.image_placeholder_id = 999
self.image_newline_id = 888 self.image_newline_id = 888
self.variable_sized = True self.variable_sized = True
self.image_processor = FuyuImageProcessor(
patch_size={"height": self.image_patch_dim_h, "width": self.image_patch_dim_w}
)
def test_process_images_for_model_input_fixed_sized(self): def test_process_images_for_model_input_fixed_sized(self):
self.variable_sized = False self.variable_sized = False
result = self.image_processor.process_images_for_model_input( result = self.image_processor.preprocess_with_tokenizer_info(
image_input=self.image_input, image_input=self.image_input,
image_present=self.image_present, image_present=self.image_present,
image_unpadded_h=self.image_unpadded_h, image_unpadded_h=self.image_unpadded_h,
image_unpadded_w=self.image_unpadded_w, image_unpadded_w=self.image_unpadded_w,
image_patch_dim_h=self.image_patch_dim_h,
image_patch_dim_w=self.image_patch_dim_w,
image_placeholder_id=self.image_placeholder_id, image_placeholder_id=self.image_placeholder_id,
image_newline_id=self.image_newline_id, image_newline_id=self.image_newline_id,
variable_sized=self.variable_sized, variable_sized=self.variable_sized,
) )
print(result["images"][0][0])
self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64])) self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64]))