Unverified Commit 76924384 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Vilt - use image_transforms pad (#20780)

Use image_transforms pad
parent ecd7de3d
......@@ -189,6 +189,7 @@ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
return norm_annotation
# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
......@@ -196,6 +197,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return [max(values_i) for values_i in zip(*values)]
# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
......@@ -211,6 +213,7 @@ def get_max_height_width(images: List[np.ndarray]) -> List[int]:
return (max_height, max_width)
# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
......
......@@ -23,7 +23,7 @@ from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_transforms import PaddingMode, normalize, pad, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
......@@ -53,46 +53,6 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return [max(values_i) for values_i in zip(*values)]
def pad(
image: np.ndarray,
output_size: Tuple[int, int],
input_channel_dimension: Optional[ChannelDimension] = None,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pad the bottom and right of the image with zeros to the output size.
Args:
image (`np.ndarray`):
Image to pad.
output_size (`Tuple[int, int]`):
Output size of the image.
input_channel_dimension (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be inferred from the input image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
if input_channel_dimension is None:
input_channel_dimension = infer_channel_dimension_format(image)
output_height, output_width = output_size
input_height, input_width = get_image_size(image)
pad_bottom = output_height - input_height
pad_right = output_width - input_width
if input_channel_dimension == ChannelDimension.FIRST:
padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
elif input_channel_dimension == ChannelDimension.LAST:
padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
if data_format is not None:
padded_image = to_channel_dimension_format(padded_image, data_format)
return padded_image
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
......@@ -109,7 +69,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
return mask
def get_max_dimensions(images: List[np.ndarray]) -> List[int]:
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
......@@ -304,6 +264,27 @@ class ViltImageProcessor(BaseImageProcessor):
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def _pad_image(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
)
return padded_image
def pad(
self,
images: List[np.ndarray],
......@@ -330,8 +311,10 @@ class ViltImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
pad_size = get_max_dimensions(images)
padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
pad_size = get_max_height_width(images)
padded_images = [
self._pad_image(image=image, output_size=pad_size, data_format=data_format) for image in images
]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment