Unverified commit d18a1cba, authored by amyeroberts, committed by GitHub

Accept batched tensor of images as input to image processor (#21144)

* Accept a batched tensor of images as input

* Add to all image processors

* Update oneformer
parent 6f3faf38
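
The change is the same in every processor touched by this diff: the manual `is_batched` check is replaced by the new `make_list_of_images` helper, which additionally unstacks a batched 4D array or tensor into a list of images. A minimal usage sketch, mirroring the tests added at the bottom of this diff (the shapes are illustrative only):

    import numpy as np

    from transformers.image_utils import make_list_of_images

    # A single (H, W, C) image is wrapped into a list of length 1.
    image = np.random.randint(0, 256, (16, 32, 3))
    assert len(make_list_of_images(image)) == 1

    # A batched (N, H, W, C) array is unstacked into a list of N images;
    # this is the new input format the commit accepts.
    batch = np.random.randint(0, 256, (4, 16, 32, 3))
    assert len(make_list_of_images(batch)) == 4

    # A list of images passes through as a list of the same length.
    images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
    assert len(make_list_of_images(images)) == 4

    # Segmentation maps have no channel dimension, so OneFormer and
    # Segformer unstack batched (N, H, W) masks with expected_ndims=2.
    masks = np.random.randint(0, 2, (4, 16, 32))
    assert len(make_list_of_images(masks, expected_ndims=2)) == 4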
@@ -35,7 +35,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -288,8 +288,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -36,7 +36,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -295,8 +295,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -28,7 +28,7 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -284,8 +284,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -38,7 +38,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     valid_images,
 )
 from transformers.utils import (
@@ -676,9 +676,9 @@ class OneFormerImageProcessor(BaseImageProcessor):
                 "torch.Tensor, tf.Tensor or jax.ndarray."
             )
 
-        if not is_batched(images):
-            images = [images]
-            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+        images = make_list_of_images(images)
+        if segmentation_maps is not None:
+            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
 
         if segmentation_maps is not None and len(images) != len(segmentation_maps):
             raise ValueError("Images and segmentation maps must have the same length.")
@@ -29,7 +29,13 @@ from transformers.image_transforms import (
     to_channel_dimension_format,
     to_numpy_array,
 )
-from transformers.image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, valid_images
+from transformers.image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    valid_images,
+)
 from transformers.utils import TensorType, is_torch_available, logging
@@ -300,8 +306,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
         if do_normalize is not None and (image_mean is None or image_std is None):
             raise ValueError("Image mean and std must be specified if do_normalize is True.")
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -30,7 +30,7 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     get_image_size,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -289,8 +289,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -36,7 +36,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -339,8 +339,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -30,7 +30,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -385,9 +385,9 @@ class SegformerImageProcessor(BaseImageProcessor):
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
 
-        if not is_batched(images):
-            images = [images]
-            segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+        images = make_list_of_images(images)
+        if segmentation_maps is not None:
+            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
 
         if not valid_images(images):
             raise ValueError(
@@ -22,7 +22,7 @@ from transformers.utils.generic import TensorType
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature
 from ...image_transforms import get_image_size, pad, rescale, to_channel_dimension_format
-from ...image_utils import ChannelDimension, ImageInput, is_batched, to_numpy_array, valid_images
+from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images
 from ...utils import logging
@@ -148,8 +148,7 @@ class Swin2SRImageProcessor(BaseImageProcessor):
         do_pad = do_pad if do_pad is not None else self.do_pad
         pad_size = pad_size if pad_size is not None else self.pad_size
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -32,7 +32,7 @@ from ...image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -441,8 +441,7 @@ class ViltImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -28,7 +28,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_images,
 )
@@ -243,8 +243,7 @@ class ViTImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size_dict = get_size_dict(size)
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -30,7 +30,14 @@ from ...image_transforms import (
     resize,
     to_channel_dimension_format,
 )
-from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
 from ...utils import logging
 from ...utils.import_utils import is_vision_available
@@ -286,8 +293,7 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         image_std = image_std if image_std is not None else self.image_std
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
 
-        if not is_batched(images):
-            images = [images]
+        images = make_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
@@ -42,7 +42,7 @@ from transformers.image_utils import (
     PILImageResampling,
     get_image_size,
     infer_channel_dimension_format,
-    is_batched,
+    make_list_of_images,
     to_numpy_array,
     valid_coco_detection_annotations,
     valid_coco_panoptic_annotations,
@@ -1038,9 +1038,9 @@ class YolosImageProcessor(BaseImageProcessor):
         if do_normalize is not None and (image_mean is None or image_std is None):
             raise ValueError("Image mean and std must be specified if do_normalize is True.")
 
-        if not is_batched(images):
-            images = [images]
-            annotations = [annotations] if annotations is not None else None
+        images = make_list_of_images(images)
+        if annotations is not None and isinstance(annotations[0], dict):
+            annotations = [annotations]
 
         if annotations is not None and len(images) != len(annotations):
             raise ValueError(
@@ -20,7 +20,7 @@ import numpy as np
 import pytest
 
 from transformers import is_torch_available, is_vision_available
-from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
+from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images
 from transformers.testing_utils import require_torch, require_vision
@@ -102,6 +102,58 @@ class ImageFeatureExtractionTester(unittest.TestCase):
         self.assertEqual(array5.shape, (3, 16, 32))
         self.assertTrue(np.array_equal(array5, array1))
 
+    def test_make_list_of_images_numpy(self):
+        # Test a single image is converted to a list of 1 image
+        images = np.random.randint(0, 256, (16, 32, 3))
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 1)
+        self.assertTrue(np.array_equal(images_list[0], images))
+        self.assertIsInstance(images_list, list)
+
+        # Test a batch of images is converted to a list of images
+        images = np.random.randint(0, 256, (4, 16, 32, 3))
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
+        self.assertIsInstance(images_list, list)
+
+        # Test a list of images is not modified
+        images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
+        self.assertIsInstance(images_list, list)
+
+        # Test batched masks with no channel dimension are converted to a list of masks
+        masks = np.random.randint(0, 2, (4, 16, 32))
+        masks_list = make_list_of_images(masks, expected_ndims=2)
+        self.assertEqual(len(masks_list), 4)
+        self.assertTrue(np.array_equal(masks_list[0], masks[0]))
+        self.assertIsInstance(masks_list, list)
+
+    @require_torch
+    def test_make_list_of_images_torch(self):
+        # Test a single image is converted to a list of 1 image
+        images = torch.randint(0, 256, (16, 32, 3))
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 1)
+        self.assertTrue(np.array_equal(images_list[0], images))
+        self.assertIsInstance(images_list, list)
+
+        # Test a batch of images is converted to a list of images
+        images = torch.randint(0, 256, (4, 16, 32, 3))
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
+        self.assertIsInstance(images_list, list)
+
+        # Test a list of images is left unchanged
+        images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
+        images_list = make_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
+        self.assertIsInstance(images_list, list)
+
     @require_torch
     def test_conversion_torch_to_array(self):
         feature_extractor = ImageFeatureExtractionMixin()