"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d78e78a0e44da2b74e9d98d608ccd4e74b311cd9"
Unverified Commit d18a1cba authored by amyeroberts, committed by GitHub

Accept batched tensor of images as input to image processor (#21144)

* Accept a batched tensor of images as input

* Add to all image processors

* Update oneformer
parent 6f3faf38
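
For readers skimming the diff, here is a minimal sketch of what this change enables at the call site. It is not part of the commit: it assumes `numpy` is installed and uses `openai/clip-vit-base-patch32` purely as an example checkpoint, since every processor touched below routes its input through the new helper.

```python
import numpy as np
from transformers import AutoImageProcessor

# Example checkpoint only; any model with an image processor behaves the same way.
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Previously supported: a single image or a list of images.
image_list = [np.random.randint(0, 256, (3, 224, 224), dtype=np.uint8) for _ in range(4)]
out_list = image_processor(image_list, return_tensors="np")

# Newly supported: one batched array/tensor of shape (batch, channels, height, width),
# which the processor now splits into a list of per-image arrays internally.
image_batch = np.random.randint(0, 256, (4, 3, 224, 224), dtype=np.uint8)
out_batch = image_processor(image_batch, return_tensors="np")

assert out_list["pixel_values"].shape == out_batch["pixel_values"].shape  # (4, 3, 224, 224)
```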
@@ -91,6 +91,45 @@ def is_batched(img):
     return False


+def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
+    """
+    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
+    If the input is a batch of images, it is converted to a list of images.
+
+    Args:
+        images (`ImageInput`):
+            Image or images to turn into a list of images.
+        expected_ndims (`int`, *optional*, defaults to 3):
+            Expected number of dimensions for a single input image. If the input image has a different number of
+            dimensions, an error is raised.
+    """
+    if is_batched(images):
+        return images
+
+    # Either the input is a single image, in which case we create a list of length 1
+    if isinstance(images, PIL.Image.Image):
+        # PIL images are never batched
+        return [images]
+
+    if is_valid_image(images):
+        if images.ndim == expected_ndims + 1:
+            # Batch of images
+            images = [image for image in images]
+        elif images.ndim == expected_ndims:
+            # Single image
+            images = [images]
+        else:
+            raise ValueError(
+                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
+                f" {images.ndim} dimensions."
+            )
+        return images
+    raise ValueError(
+        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
+        f"jax.ndarray, but got {type(images)}."
+    )
+
+
 def to_numpy_array(img) -> np.ndarray:
     if not is_valid_image(img):
         raise ValueError(f"Invalid image type: {type(img)}")
...
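
The new helper can also be exercised on its own; a small sketch of the behavior documented in the docstring above, assuming only `numpy` is installed:

```python
import numpy as np
from transformers.image_utils import make_list_of_images

single = np.zeros((3, 64, 64), dtype=np.uint8)      # one image: ndim == expected_ndims (3)
batched = np.zeros((8, 3, 64, 64), dtype=np.uint8)  # batched tensor: ndim == expected_ndims + 1

assert len(make_list_of_images(single)) == 1        # wrapped into a list of length 1
assert len(make_list_of_images(batched)) == 8       # split along the batch dimension
assert len(make_list_of_images([single, single])) == 2  # lists pass through unchanged

# Segmentation maps are 2D per image, so a batched map is 3D: pass expected_ndims=2.
seg_maps = np.zeros((8, 64, 64), dtype=np.uint8)
assert len(make_list_of_images(seg_maps, expected_ndims=2)) == 8

# Any other rank raises the ValueError defined above.
try:
    make_list_of_images(np.zeros((2, 8, 3, 64, 64), dtype=np.uint8))
except ValueError as err:
    print(err)  # Invalid image shape. Expected either 4 or 3 dimensions, but got 5 dimensions.
```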
@@ -30,7 +30,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -438,9 +438,9 @@ class BeitImageProcessor(BaseImageProcessor):
         image_std = image_std if image_std is not None else self.image_std
         do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

-        if not is_batched(images):
-            images = [images]
-            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+        images = make_list_of_images(images)
+        if segmentation_maps is not None:
+            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

         if not valid_images(images):
             raise ValueError(
...
@@ -30,7 +30,14 @@ from ...image_transforms import (
     resize,
     to_channel_dimension_format,
 )
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
 from ...utils.import_utils import is_vision_available
@@ -286,8 +293,7 @@ class BitImageProcessor(BaseImageProcessor):
         image_std = image_std if image_std is not None else self.image_std
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -29,7 +29,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -247,8 +247,7 @@ class BlipImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -30,7 +30,14 @@ from ...image_transforms import (
     resize,
     to_channel_dimension_format,
 )
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
 from ...utils.import_utils import is_vision_available
@@ -284,8 +291,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
         image_std = image_std if image_std is not None else self.image_std
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -30,7 +30,14 @@ from ...image_transforms import (
     resize,
     to_channel_dimension_format,
 )
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
 from ...utils.import_utils import is_vision_available
@@ -286,8 +293,7 @@ class CLIPImageProcessor(BaseImageProcessor):
         image_std = image_std if image_std is not None else self.image_std
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -44,7 +44,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_coco_detection_annotations,
     valid_coco_panoptic_annotations,
@@ -1172,9 +1172,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
         if do_normalize is not None and (image_mean is None or image_std is None):
             raise ValueError("Image mean and std must be specified if do_normalize is True.")

-        if not is_batched(images):
-            images = [images]
-            annotations = [annotations] if annotations is not None else None
+        images = make_list_of_images(images)
+        if annotations is not None and isinstance(annotations[0], dict):
+            annotations = [annotations]

         if annotations is not None and len(images) != len(annotations):
             raise ValueError(
...
@@ -36,7 +36,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -272,8 +272,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -44,7 +44,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_coco_detection_annotations,
     valid_coco_panoptic_annotations,
@@ -1170,9 +1170,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         if do_normalize is not None and (image_mean is None or image_std is None):
             raise ValueError("Image mean and std must be specified if do_normalize is True.")

-        if not is_batched(images):
-            images = [images]
-            annotations = [annotations] if annotations is not None else None
+        images = make_list_of_images(images)
+        if annotations is not None and isinstance(annotations[0], dict):
+            annotations = [annotations]

         if annotations is not None and len(images) != len(annotations):
             raise ValueError(
...
@@ -29,7 +29,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -276,8 +276,7 @@ class DeiTImageProcessor(BaseImageProcessor):
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -43,7 +43,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_coco_detection_annotations,
     valid_coco_panoptic_annotations,
@@ -1138,9 +1138,9 @@ class DetrImageProcessor(BaseImageProcessor):
         if do_normalize is not None and (image_mean is None or image_std is None):
             raise ValueError("Image mean and std must be specified if do_normalize is True.")

-        if not is_batched(images):
-            images = [images]
-            annotations = [annotations] if annotations is not None else None
+        images = make_list_of_images(images)
+        if annotations is not None and isinstance(annotations[0], dict):
+            annotations = [annotations]

         if annotations is not None and len(images) != len(annotations):
             raise ValueError(
...
@@ -34,7 +34,7 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     get_image_size,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -396,8 +396,7 @@ class DonutImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -31,9 +31,9 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     get_image_size,
-    is_batched,
     is_torch_available,
     is_torch_tensor,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -308,8 +308,7 @@ class DPTImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -26,7 +26,14 @@ from transformers.utils.generic import TensorType
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
@@ -647,8 +654,7 @@ class FlavaImageProcessor(BaseImageProcessor):
         codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
         codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -24,7 +24,7 @@ from transformers.utils.generic import TensorType
 from ...image_processing_utils import BaseImageProcessor, BatchFeature
 from ...image_transforms import rescale, resize, to_channel_dimension_format
-from ...image_utils import ChannelDimension, get_image_size, is_batched, to_numpy_array, valid_images
+from ...image_utils import ChannelDimension, get_image_size, make_list_of_images, to_numpy_array, valid_images
 from ...utils import logging
@@ -166,8 +166,7 @@ class GLPNImageProcessor(BaseImageProcessor):
         if do_resize and size_divisor is None:
             raise ValueError("size_divisor is required for resizing")

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError("Invalid image(s)")
...
@@ -23,7 +23,14 @@ from transformers.utils.generic import TensorType
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import rescale, resize, to_channel_dimension_format
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
@@ -196,8 +203,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
         do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
         clusters = clusters if clusters is not None else self.clusters

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -28,7 +28,7 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -230,8 +230,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
         ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
         tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -30,7 +30,7 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -320,8 +320,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
         tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -35,7 +35,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -303,8 +303,7 @@ class LevitImageProcessor(BaseImageProcessor):
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")

-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)

         if not valid_images(images):
             raise ValueError(
...
@@ -37,7 +37,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     valid_images,
 )
 from transformers.utils import (
@@ -717,9 +717,9 @@ class MaskFormerImageProcessor(BaseImageProcessor):
                 "torch.Tensor, tf.Tensor or jax.ndarray."
             )

-        if not is_batched(images):
-            images = [images]
-            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+        images = make_list_of_images(images)
+        if segmentation_maps is not None:
+            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

         if segmentation_maps is not None and len(images) != len(segmentation_maps):
             raise ValueError("Images and segmentation maps must have the same length.")
...
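
Because the per-model change is identical across all of the processors above, the new behavior is framework-agnostic. A hedged sketch with a batched `torch` tensor and a processor constructed from its defaults; `DeiTImageProcessor()` is chosen arbitrarily for illustration, and `torch` plus `Pillow` are assumed to be installed:

```python
import torch
from transformers import DeiTImageProcessor

# Default-constructed processor; no checkpoint download needed.
image_processor = DeiTImageProcessor()

# A batched torch tensor of shape (batch, channels, height, width) is now accepted directly.
pixel_batch = torch.randint(0, 256, (2, 3, 224, 224), dtype=torch.uint8)
outputs = image_processor(pixel_batch, return_tensors="pt")

print(outputs["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```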