Unverified commit 76924384, authored by amyeroberts and committed by GitHub

Vilt - use image_transforms pad (#20780)

Use image_transforms pad
parent ecd7de3d
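
This change swaps the module-local padding helper in the ViLT image processor for the shared `image_transforms.pad`. As a minimal sketch of the code path this serves, assuming `ViltImageProcessor` is importable from the top-level `transformers` package and that the mask comes back under a `pixel_mask` key (neither detail is shown in the diff itself):

```python
import numpy as np
from transformers import ViltImageProcessor  # assumed public import path

# Two channels-last images with different heights/widths.
images = [
    np.zeros((32, 32, 3), dtype=np.float32),
    np.zeros((24, 48, 3), dtype=np.float32),
]

processor = ViltImageProcessor()
# `pad` pads every image to the batch-wide max height/width (here 32 x 48)
# and, when requested, also returns a mask marking the valid (non-padded) pixels.
encoding = processor.pad(images, return_pixel_mask=True)

print(encoding["pixel_values"][1].shape)  # expected (32, 48, 3)
print(encoding["pixel_mask"][1].sum())    # expected 24 * 48 valid pixels
```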
@@ -189,6 +189,7 @@ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
     return norm_annotation
 
 
+# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:
     """
     Return the maximum value across all indices of an iterable of values.
@@ -196,6 +197,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
     return [max(values_i) for values_i in zip(*values)]
 
 
+# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
 def get_max_height_width(images: List[np.ndarray]) -> List[int]:
     """
     Get the maximum height and width across all images in a batch.
@@ -211,6 +213,7 @@ def get_max_height_width(images: List[np.ndarray]) -> List[int]:
     return (max_height, max_width)
 
 
+# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
 def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
     """
     Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
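
The helpers above now carry `# Copied from` markers pointing at the ViLT implementations. A standalone sketch of what they compute, re-implemented here with plain NumPy and assuming channels-last arrays (the library versions also handle channels-first input):

```python
import numpy as np

def max_across_indices(values):
    # Element-wise maximum across an iterable of equal-length sequences.
    return [max(values_i) for values_i in zip(*values)]

def get_max_height_width(images):
    # Largest (height, width) over a batch of channels-last images.
    return tuple(max_across_indices([img.shape[:2] for img in images]))

def make_pixel_mask(image, output_size):
    # 1 marks pixels belonging to the original image, 0 marks padding.
    input_height, input_width = image.shape[:2]
    mask = np.zeros(output_size, dtype=np.int64)
    mask[:input_height, :input_width] = 1
    return mask

batch = [np.ones((24, 48, 3)), np.ones((32, 32, 3))]
print(get_max_height_width(batch))                # (32, 48)
print(make_pixel_mask(batch[0], (32, 48)).sum())  # 1152 == 24 * 48
```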
@@ -23,7 +23,7 @@ from transformers.utils import is_vision_available
 from transformers.utils.generic import TensorType
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
+from ...image_transforms import PaddingMode, normalize, pad, rescale, resize, to_channel_dimension_format
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
     IMAGENET_STANDARD_STD,
@@ -53,46 +53,6 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
     return [max(values_i) for values_i in zip(*values)]
 
 
-def pad(
-    image: np.ndarray,
-    output_size: Tuple[int, int],
-    input_channel_dimension: Optional[ChannelDimension] = None,
-    data_format: Optional[ChannelDimension] = None,
-) -> np.ndarray:
-    """
-    Pad the bottom and right of the image with zeros to the output size.
-
-    Args:
-        image (`np.ndarray`):
-            Image to pad.
-        output_size (`Tuple[int, int]`):
-            Output size of the image.
-        input_channel_dimension (`ChannelDimension`, *optional*):
-            The channel dimension format of the image. If not provided, it will be inferred from the input image.
-        data_format (`str` or `ChannelDimension`, *optional*):
-            The channel dimension format of the image. If not provided, it will be the same as the input image.
-    """
-    if input_channel_dimension is None:
-        input_channel_dimension = infer_channel_dimension_format(image)
-
-    output_height, output_width = output_size
-    input_height, input_width = get_image_size(image)
-    pad_bottom = output_height - input_height
-    pad_right = output_width - input_width
-
-    if input_channel_dimension == ChannelDimension.FIRST:
-        padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
-    elif input_channel_dimension == ChannelDimension.LAST:
-        padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
-    else:
-        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
-
-    if data_format is not None:
-        padded_image = to_channel_dimension_format(padded_image, data_format)
-
-    return padded_image
-
-
 def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
     """
     Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
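
The module-local `pad` above is deleted in favor of the shared helper imported from `...image_transforms`. A rough equivalent of the removed behavior, written against that shared `pad` the way the new `_pad_image` below uses it; the top-level import path `transformers.image_transforms` is assumed from the relative import in the diff:

```python
import numpy as np
from transformers.image_transforms import PaddingMode, pad  # assumed public import path

image = np.ones((24, 40, 3), dtype=np.float32)  # channels-last example
output_height, output_width = 32, 48

# Pad only on the bottom and right with constant zeros, exactly as the removed
# helper did; channel-dimension handling is left to the shared `pad`.
padding = ((0, output_height - image.shape[0]), (0, output_width - image.shape[1]))
padded = pad(image, padding, mode=PaddingMode.CONSTANT, constant_values=0)

print(padded.shape)  # expected (32, 48, 3)
```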
@@ -109,7 +69,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
     return mask
 
 
-def get_max_dimensions(images: List[np.ndarray]) -> List[int]:
+def get_max_height_width(images: List[np.ndarray]) -> List[int]:
     """
     Get the maximum height and width across all images in a batch.
     """
@@ -304,6 +264,27 @@ class ViltImageProcessor(BaseImageProcessor):
         """
         return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
 
+    def _pad_image(
+        self,
+        image: np.ndarray,
+        output_size: Tuple[int, int],
+        constant_values: Union[float, Iterable[float]] = 0,
+        data_format: Optional[ChannelDimension] = None,
+    ) -> np.ndarray:
+        """
+        Pad an image with zeros to the given size.
+        """
+        input_height, input_width = get_image_size(image)
+        output_height, output_width = output_size
+
+        pad_bottom = output_height - input_height
+        pad_right = output_width - input_width
+        padding = ((0, pad_bottom), (0, pad_right))
+        padded_image = pad(
+            image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
+        )
+        return padded_image
+
     def pad(
         self,
         images: List[np.ndarray],
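
The new `_pad_image` additionally exposes `constant_values` and `data_format`. A small sketch of how those two arguments are expected to behave when forwarded to the shared `pad` (same assumed import path as above; the channels-first output shape is my expectation for `data_format=ChannelDimension.FIRST`, not something the diff asserts):

```python
import numpy as np
from transformers.image_transforms import PaddingMode, pad  # assumed public import path
from transformers.image_utils import ChannelDimension

image = np.ones((24, 48, 3), dtype=np.float32)

# Same bottom/right padding as `_pad_image`, but filling with 0.5 instead of 0
# and asking for a channels-first result.
padded = pad(
    image,
    ((0, 8), (0, 0)),
    mode=PaddingMode.CONSTANT,
    constant_values=0.5,
    data_format=ChannelDimension.FIRST,
)
print(padded.shape)      # expected (3, 32, 48)
print(padded[0, -1, 0])  # expected 0.5 in the padded rows
```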
@@ -330,8 +311,10 @@ class ViltImageProcessor(BaseImageProcessor):
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
-        pad_size = get_max_dimensions(images)
-        padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
+        pad_size = get_max_height_width(images)
+        padded_images = [
+            self._pad_image(image=image, output_size=pad_size, data_format=data_format) for image in images
+        ]
         data = {"pixel_values": padded_images}
         if return_pixel_mask:
             masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
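
Putting the pieces together, the batched `pad` now computes the batch-wide max size, pads each image via `_pad_image`, and optionally builds the masks. A self-contained NumPy sketch of that flow; the `pixel_mask` key mirrors the `return_pixel_mask` intent in the diff, but the exact output key name is an assumption:

```python
import numpy as np

def pad_batch(images, return_pixel_mask=True):
    # Batch-wide target size: the max height and max width over all images.
    heights, widths = zip(*(img.shape[:2] for img in images))
    pad_size = (max(heights), max(widths))

    padded, masks = [], []
    for img in images:
        pad_bottom = pad_size[0] - img.shape[0]
        pad_right = pad_size[1] - img.shape[1]
        # Bottom/right zero padding, as in `_pad_image`.
        padded.append(np.pad(img, ((0, pad_bottom), (0, pad_right), (0, 0))))
        mask = np.zeros(pad_size, dtype=np.int64)
        mask[: img.shape[0], : img.shape[1]] = 1
        masks.append(mask)

    data = {"pixel_values": padded}
    if return_pixel_mask:
        data["pixel_mask"] = masks
    return data

out = pad_batch([np.ones((24, 48, 3)), np.ones((32, 32, 3))])
print(out["pixel_values"][1].shape)  # (32, 48, 3)
print(out["pixel_mask"][0].sum())    # 1152
```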