"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "8cd3677113004ac5ae2df4afab0e48b5e83443da"
Unverified commit 0f67ba1d, authored by amyeroberts and committed by GitHub
Browse files

Add ViTImageProcessorFast to tests (#31424)

* Add ViTImageProcessor to tests

* Correct data format

* Review comments
parent aab08297
...@@ -72,6 +72,8 @@ class Swin2SRImageProcessingTester(unittest.TestCase): ...@@ -72,6 +72,8 @@ class Swin2SRImageProcessingTester(unittest.TestCase):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
input_width, input_height = img.size input_width, input_height = img.size
elif isinstance(img, np.ndarray):
input_height, input_width = img.shape[-3:-1]
else: else:
input_height, input_width = img.shape[-2:] input_height, input_width = img.shape[-2:]
...@@ -160,7 +162,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): ...@@ -160,7 +162,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Test not batched input # Test not batched input
encoded_images = image_processing( encoded_images = image_processing(
image_inputs[0], return_tensors="pt", input_data_format="channels_first" image_inputs[0], return_tensors="pt", input_data_format="channels_last"
).pixel_values ).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
......
...@@ -285,7 +285,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) ...@@ -285,7 +285,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoded_images = image_processor( encoded_images = image_processor(
image_inputs[0], image_inputs[0],
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values_images ).pixel_values_images
...@@ -296,7 +296,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) ...@@ -296,7 +296,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoded_images = image_processor( encoded_images = image_processor(
image_inputs, image_inputs,
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values_images ).pixel_values_images
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
...@@ -78,6 +80,8 @@ class ViltImageProcessingTester(unittest.TestCase): ...@@ -78,6 +80,8 @@ class ViltImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
scale = size / min(w, h) scale = size / min(w, h)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
...@@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im ...@@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available(): if is_vision_available():
from transformers import ViTImageProcessor from transformers import ViTImageProcessor
if is_torchvision_available():
from transformers import ViTImageProcessorFast
class ViTImageProcessingTester(unittest.TestCase): class ViTImageProcessingTester(unittest.TestCase):
def __init__( def __init__(
...@@ -82,6 +85,7 @@ class ViTImageProcessingTester(unittest.TestCase): ...@@ -82,6 +85,7 @@ class ViTImageProcessingTester(unittest.TestCase):
@require_vision @require_vision
class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViTImageProcessor if is_vision_available() else None image_processing_class = ViTImageProcessor if is_vision_available() else None
fast_image_processing_class = ViTImageProcessorFast if is_torchvision_available() else None
def setUp(self): def setUp(self):
super().setUp() super().setUp()
......
...@@ -18,6 +18,7 @@ import json ...@@ -18,6 +18,7 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from parameterized import parameterized from parameterized import parameterized
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
...@@ -89,6 +90,8 @@ class YolosImageProcessingTester(unittest.TestCase): ...@@ -89,6 +90,8 @@ class YolosImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
width, height = image.size width, height = image.size
elif isinstance(image, np.ndarray):
height, width = image.shape[0], image.shape[1]
else: else:
height, width = image.shape[1], image.shape[2] height, width = image.shape[1], image.shape[2]
......
...@@ -18,7 +18,9 @@ import json ...@@ -18,7 +18,9 @@ import json
import os import os
import pathlib import pathlib
import tempfile import tempfile
import time
import numpy as np
import requests import requests
from transformers import AutoImageProcessor, BatchFeature from transformers import AutoImageProcessor, BatchFeature
...@@ -28,7 +30,6 @@ from transformers.utils import is_torch_available, is_vision_available ...@@ -28,7 +30,6 @@ from transformers.utils import is_torch_available, is_vision_available
if is_torch_available(): if is_torch_available():
import numpy as np
import torch import torch
if is_vision_available(): if is_vision_available():
...@@ -72,6 +73,10 @@ def prepare_image_inputs( ...@@ -72,6 +73,10 @@ def prepare_image_inputs(
if torchify: if torchify:
image_inputs = [torch.from_numpy(image) for image in image_inputs] image_inputs = [torch.from_numpy(image) for image in image_inputs]
if numpify:
# Numpy images are typically in channels last format
image_inputs = [image.transpose(1, 2, 0) for image in image_inputs]
return image_inputs return image_inputs
...@@ -167,33 +172,28 @@ class ImageProcessingTestMixin: ...@@ -167,33 +172,28 @@ class ImageProcessingTestMixin:
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3)) self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2))
@require_vision @require_vision
@require_torch @require_torch
def test_fast_is_faster_than_slow(self): def test_fast_is_faster_than_slow(self):
import time
def measure_time(self, image_processor, dummy_image):
start = time.time()
_ = image_processor(dummy_image, return_tensors="pt")
return time.time() - start
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
if not self.test_slow_image_processor or not self.test_fast_image_processor: if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest("Skipping speed test") self.skipTest("Skipping speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None: if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping speed test as one of the image processors is not defined") self.skipTest("Skipping speed test as one of the image processors is not defined")
def measure_time(image_processor, image):
start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start
dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8)
image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class()
slow_time = self.measure_time(image_processor_slow, dummy_image) fast_time = measure_time(image_processor_fast, dummy_images)
fast_time = self.measure_time(image_processor_fast, dummy_image) slow_time = measure_time(image_processor_slow, dummy_images)
self.assertLessEqual(fast_time, slow_time) self.assertLessEqual(fast_time, slow_time)
...@@ -238,6 +238,52 @@ class ImageProcessingTestMixin: ...@@ -238,6 +238,52 @@ class ImageProcessingTestMixin:
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_save_load_fast_slow(self):
    """Test that we can load a fast image processor from a slow one and vice-versa.

    A processor of each flavour is built from the tester's config dict, saved to
    a temporary directory and reloaded through the *other* class. Each reloaded
    processor must serialize (``to_dict``) identically to a processor of its own
    flavour built directly from the same config dict.
    """
    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")

    def reload_as(processor, target_class):
        # Save `processor` to a scratch directory and load it back via `target_class`.
        with tempfile.TemporaryDirectory() as save_dir:
            processor.save_pretrained(save_dir)
            return target_class.from_pretrained(save_dir)

    config = self.image_processor_tester.prepare_image_processor_dict()

    # slow -> disk -> fast
    slow_direct = self.image_processing_class(**config)
    fast_reloaded = reload_as(slow_direct, self.fast_image_processing_class)

    # fast -> disk -> slow
    fast_direct = self.fast_image_processing_class(**config)
    slow_reloaded = reload_as(fast_direct, self.image_processing_class)

    self.assertEqual(slow_direct.to_dict(), slow_reloaded.to_dict())
    self.assertEqual(fast_reloaded.to_dict(), fast_direct.to_dict())
def test_save_load_fast_slow_auto(self):
    """Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor.

    Same round trip as the class-based save/load test, except that reloading
    goes through ``AutoImageProcessor.from_pretrained`` with ``use_fast``
    selecting which flavour is requested.
    """
    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")

    def reload_via_auto(processor, use_fast):
        # Save `processor` to a scratch directory and reload it through the auto class.
        with tempfile.TemporaryDirectory() as save_dir:
            processor.save_pretrained(save_dir)
            return AutoImageProcessor.from_pretrained(save_dir, use_fast=use_fast)

    config = self.image_processor_tester.prepare_image_processor_dict()

    # slow -> disk -> auto (fast requested)
    slow_direct = self.image_processing_class(**config)
    fast_reloaded = reload_via_auto(slow_direct, True)

    # fast -> disk -> auto (slow requested)
    fast_direct = self.fast_image_processing_class(**config)
    slow_reloaded = reload_via_auto(fast_direct, False)

    self.assertEqual(slow_direct.to_dict(), slow_reloaded.to_dict())
    self.assertEqual(fast_reloaded.to_dict(), fast_direct.to_dict())
def test_init_without_params(self): def test_init_without_params(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
image_processor = image_processing_class() image_processor = image_processing_class()
...@@ -358,7 +404,7 @@ class ImageProcessingTestMixin: ...@@ -358,7 +404,7 @@ class ImageProcessingTestMixin:
encoded_images = image_processor( encoded_images = image_processor(
image_inputs[0], image_inputs[0],
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values ).pixel_values
...@@ -369,7 +415,7 @@ class ImageProcessingTestMixin: ...@@ -369,7 +415,7 @@ class ImageProcessingTestMixin:
encoded_images = image_processor( encoded_images = image_processor(
image_inputs, image_inputs,
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values ).pixel_values
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment