Unverified Commit 41d56ea6 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Refactor image processor testers (#25450)

* Refactor image processor test mixin

- Move test_call_numpy, test_call_pytorch, test_call_pil to mixin
- Rename mixin to reflect handling of logic more than saving
- Add prepare_image_inputs, expected_image_outputs for tests

* Fix for oneformer
parent 454957c9
...@@ -18,12 +18,10 @@ import json ...@@ -18,12 +18,10 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingSavingTestMixin, prepare_image_inputs from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available(): if is_torch_available():
...@@ -111,10 +109,25 @@ class YolosImageProcessingTester(unittest.TestCase): ...@@ -111,10 +109,25 @@ class YolosImageProcessingTester(unittest.TestCase):
return expected_height, expected_width return expected_height, expected_width
def expected_output_image_shape(self, images):
height, width = self.get_expected_values(images, batched=True)
return self.num_channels, height, width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch @require_torch
@require_vision @require_vision
class YolosImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase): class YolosImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = YolosImageProcessor if is_vision_available() else None image_processing_class = YolosImageProcessor if is_vision_available() else None
def setUp(self): def setUp(self):
...@@ -143,113 +156,12 @@ class YolosImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase ...@@ -143,113 +156,12 @@ class YolosImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(image_processor.do_pad, False) self.assertEqual(image_processor.do_pad, False)
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
def test_equivalence_padding(self): def test_equivalence_padding(self):
# Initialize image_processings # Initialize image_processings
image_processing_1 = self.image_processing_class(**self.image_processor_dict) image_processing_1 = self.image_processing_class(**self.image_processor_dict)
image_processing_2 = self.image_processing_class(do_resize=False, do_normalize=False, do_rescale=False) image_processing_2 = self.image_processing_class(do_resize=False, do_normalize=False, do_rescale=False)
# create random PyTorch tensors # create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for image in image_inputs: for image in image_inputs:
self.assertIsInstance(image, torch.Tensor) self.assertIsInstance(image, torch.Tensor)
......
...@@ -29,7 +29,16 @@ if is_vision_available(): ...@@ -29,7 +29,16 @@ if is_vision_available():
from PIL import Image from PIL import Image
def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify=False, torchify=False): def prepare_image_inputs(
batch_size,
min_resolution,
max_resolution,
num_channels,
size_divisor=None,
equal_resolution=False,
numpify=False,
torchify=False,
):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True. or a list of PyTorch tensors if one specifies torchify=True.
...@@ -39,19 +48,16 @@ def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify ...@@ -39,19 +48,16 @@ def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
image_inputs = [] image_inputs = []
for i in range(image_processor_tester.batch_size): for i in range(batch_size):
if equal_resolution: if equal_resolution:
width = height = image_processor_tester.max_resolution width = height = max_resolution
else: else:
# To avoid getting image width/height 0 # To avoid getting image width/height 0
min_resolution = image_processor_tester.min_resolution if size_divisor is not None:
if getattr(image_processor_tester, "size_divisor", None):
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor` # If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
min_resolution = max(image_processor_tester.size_divisor, min_resolution) min_resolution = max(size_divisor, min_resolution)
width, height = np.random.choice(np.arange(min_resolution, image_processor_tester.max_resolution), 2) width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
image_inputs.append( image_inputs.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
np.random.randint(255, size=(image_processor_tester.num_channels, width, height), dtype=np.uint8)
)
if not numpify and not torchify: if not numpify and not torchify:
# PIL expects the channel dimension as last dimension # PIL expects the channel dimension as last dimension
...@@ -63,12 +69,12 @@ def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify ...@@ -63,12 +69,12 @@ def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify
return image_inputs return image_inputs
def prepare_video(image_processor_tester, width=10, height=10, numpify=False, torchify=False): def prepare_video(num_frames, num_channels, width=10, height=10, numpify=False, torchify=False):
"""This function prepares a video as a list of PIL images/NumPy arrays/PyTorch tensors.""" """This function prepares a video as a list of PIL images/NumPy arrays/PyTorch tensors."""
video = [] video = []
for i in range(image_processor_tester.num_frames): for i in range(num_frames):
video.append(np.random.randint(255, size=(image_processor_tester.num_channels, width, height), dtype=np.uint8)) video.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
if not numpify and not torchify: if not numpify and not torchify:
# PIL expects the channel dimension as last dimension # PIL expects the channel dimension as last dimension
...@@ -80,7 +86,16 @@ def prepare_video(image_processor_tester, width=10, height=10, numpify=False, to ...@@ -80,7 +86,16 @@ def prepare_video(image_processor_tester, width=10, height=10, numpify=False, to
return video return video
def prepare_video_inputs(image_processor_tester, equal_resolution=False, numpify=False, torchify=False): def prepare_video_inputs(
batch_size,
num_frames,
num_channels,
min_resolution,
max_resolution,
equal_resolution=False,
numpify=False,
torchify=False,
):
"""This function prepares a batch of videos: a list of list of PIL images, or a list of list of numpy arrays if """This function prepares a batch of videos: a list of list of PIL images, or a list of list of numpy arrays if
one specifies numpify=True, or a list of list of PyTorch tensors if one specifies torchify=True. one specifies numpify=True, or a list of list of PyTorch tensors if one specifies torchify=True.
...@@ -90,15 +105,14 @@ def prepare_video_inputs(image_processor_tester, equal_resolution=False, numpify ...@@ -90,15 +105,14 @@ def prepare_video_inputs(image_processor_tester, equal_resolution=False, numpify
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
video_inputs = [] video_inputs = []
for i in range(image_processor_tester.batch_size): for i in range(batch_size):
if equal_resolution: if equal_resolution:
width = height = image_processor_tester.max_resolution width = height = max_resolution
else: else:
width, height = np.random.choice( width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
np.arange(image_processor_tester.min_resolution, image_processor_tester.max_resolution), 2
)
video = prepare_video( video = prepare_video(
image_processor_tester=image_processor_tester, num_frames=num_frames,
num_channels=num_channels,
width=width, width=width,
height=height, height=height,
numpify=numpify, numpify=numpify,
...@@ -109,7 +123,7 @@ def prepare_video_inputs(image_processor_tester, equal_resolution=False, numpify ...@@ -109,7 +123,7 @@ def prepare_video_inputs(image_processor_tester, equal_resolution=False, numpify
return video_inputs return video_inputs
class ImageProcessingSavingTestMixin: class ImageProcessingTestMixin:
test_cast_dtype = None test_cast_dtype = None
def test_image_processor_to_json_string(self): def test_image_processor_to_json_string(self):
...@@ -150,7 +164,7 @@ class ImageProcessingSavingTestMixin: ...@@ -150,7 +164,7 @@ class ImageProcessingSavingTestMixin:
image_processor = self.image_processing_class(**self.image_processor_dict) image_processor = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors # create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
encoding = image_processor(image_inputs, return_tensors="pt") encoding = image_processor(image_inputs, return_tensors="pt")
# for layoutLM compatiblity # for layoutLM compatiblity
...@@ -176,3 +190,65 @@ class ImageProcessingSavingTestMixin: ...@@ -176,3 +190,65 @@ class ImageProcessingSavingTestMixin:
self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float16) self.assertEqual(encoding.pixel_values.dtype, torch.float16)
self.assertEqual(encoding.input_ids.dtype, torch.long) self.assertEqual(encoding.input_ids.dtype, torch.long)
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment