Unverified Commit f53fe35b authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Fast image processor (#28847)



* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py
Co-authored-by: default avatarPavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: default avatarPavel Iakubovskii <qubvel@gmail.com>
parent edc1dffd
...@@ -94,6 +94,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): ...@@ -94,6 +94,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VitMatteImageProcessor if is_vision_available() else None image_processing_class = VitMatteImageProcessor if is_vision_available() else None
def setUp(self): def setUp(self):
super().setUp()
self.image_processor_tester = VitMatteImageProcessingTester(self) self.image_processor_tester = VitMatteImageProcessingTester(self)
@property @property
......
...@@ -99,6 +99,7 @@ class VivitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): ...@@ -99,6 +99,7 @@ class VivitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VivitImageProcessor if is_vision_available() else None image_processing_class = VivitImageProcessor if is_vision_available() else None
def setUp(self): def setUp(self):
super().setUp()
self.image_processor_tester = VivitImageProcessingTester(self) self.image_processor_tester = VivitImageProcessingTester(self)
@property @property
......
...@@ -143,6 +143,7 @@ class YolosImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMix ...@@ -143,6 +143,7 @@ class YolosImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMix
image_processing_class = YolosImageProcessor if is_vision_available() else None image_processing_class = YolosImageProcessor if is_vision_available() else None
def setUp(self): def setUp(self):
super().setUp()
self.image_processor_tester = YolosImageProcessingTester(self) self.image_processor_tester = YolosImageProcessingTester(self)
@property @property
......
...@@ -19,7 +19,9 @@ import os ...@@ -19,7 +19,9 @@ import os
import pathlib import pathlib
import tempfile import tempfile
from transformers import BatchFeature import requests
from transformers import AutoImageProcessor, BatchFeature
from transformers.image_utils import AnnotationFormat, AnnotionFormat from transformers.image_utils import AnnotationFormat, AnnotionFormat
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
...@@ -129,43 +131,125 @@ def prepare_video_inputs( ...@@ -129,43 +131,125 @@ def prepare_video_inputs(
class ImageProcessingTestMixin: class ImageProcessingTestMixin:
test_cast_dtype = None test_cast_dtype = None
image_processing_class = None
fast_image_processing_class = None
image_processors_list = None
test_slow_image_processor = True
test_fast_image_processor = True
def setUp(self):
image_processor_list = []
if self.test_slow_image_processor and self.image_processing_class:
image_processor_list.append(self.image_processing_class)
if self.test_fast_image_processor and self.fast_image_processing_class:
image_processor_list.append(self.fast_image_processing_class)
self.image_processor_list = image_processor_list
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest("Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping slow/fast equivalence test as one of the image processors is not defined")
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3))
@require_vision
@require_torch
def test_fast_is_faster_than_slow(self):
import time
def measure_time(self, image_processor, dummy_image):
start = time.time()
_ = image_processor(dummy_image, return_tensors="pt")
return time.time() - start
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest("Skipping speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping speed test as one of the image processors is not defined")
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
slow_time = self.measure_time(image_processor_slow, dummy_image)
fast_time = self.measure_time(image_processor_fast, dummy_image)
self.assertLessEqual(fast_time, slow_time)
def test_image_processor_to_json_string(self): def test_image_processor_to_json_string(self):
image_processor = self.image_processing_class(**self.image_processor_dict) for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
obj = json.loads(image_processor.to_json_string()) obj = json.loads(image_processor.to_json_string())
for key, value in self.image_processor_dict.items(): for key, value in self.image_processor_dict.items():
self.assertEqual(obj[key], value) self.assertEqual(obj[key], value)
def test_image_processor_to_json_file(self): def test_image_processor_to_json_file(self):
image_processor_first = self.image_processing_class(**self.image_processor_dict) for image_processing_class in self.image_processor_list:
image_processor_first = image_processing_class(**self.image_processor_dict)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "image_processor.json") json_file_path = os.path.join(tmpdirname, "image_processor.json")
image_processor_first.to_json_file(json_file_path) image_processor_first.to_json_file(json_file_path)
image_processor_second = self.image_processing_class.from_json_file(json_file_path) image_processor_second = image_processing_class.from_json_file(json_file_path)
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_image_processor_from_and_save_pretrained(self): def test_image_processor_from_and_save_pretrained(self):
image_processor_first = self.image_processing_class(**self.image_processor_dict) for image_processing_class in self.image_processor_list:
image_processor_first = image_processing_class(**self.image_processor_dict)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = image_processor_first.save_pretrained(tmpdirname)[0] saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file) check_json_file_has_correct_format(saved_file)
image_processor_second = self.image_processing_class.from_pretrained(tmpdirname) image_processor_second = image_processing_class.from_pretrained(tmpdirname)
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_image_processor_save_load_with_autoimageprocessor(self):
for image_processing_class in self.image_processor_list:
image_processor_first = image_processing_class(**self.image_processor_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_init_without_params(self): def test_init_without_params(self):
image_processor = self.image_processing_class() for image_processing_class in self.image_processor_list:
image_processor = image_processing_class()
self.assertIsNotNone(image_processor) self.assertIsNotNone(image_processor)
@require_torch @require_torch
@require_vision @require_vision
def test_cast_dtype_device(self): def test_cast_dtype_device(self):
for image_processing_class in self.image_processor_list:
if self.test_cast_dtype is not None: if self.test_cast_dtype is not None:
# Initialize image_processor # Initialize image_processor
image_processor = self.image_processing_class(**self.image_processor_dict) image_processor = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors # create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
...@@ -196,8 +280,9 @@ class ImageProcessingTestMixin: ...@@ -196,8 +280,9 @@ class ImageProcessingTestMixin:
self.assertEqual(encoding.input_ids.dtype, torch.long) self.assertEqual(encoding.input_ids.dtype, torch.long)
def test_call_pil(self): def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing # Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images # create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs: for image in image_inputs:
...@@ -216,8 +301,9 @@ class ImageProcessingTestMixin: ...@@ -216,8 +301,9 @@ class ImageProcessingTestMixin:
) )
def test_call_numpy(self): def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing # Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors # create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in image_inputs: for image in image_inputs:
...@@ -236,8 +322,9 @@ class ImageProcessingTestMixin: ...@@ -236,8 +322,9 @@ class ImageProcessingTestMixin:
) )
def test_call_pytorch(self): def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing # Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors # create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
...@@ -258,9 +345,10 @@ class ImageProcessingTestMixin: ...@@ -258,9 +345,10 @@ class ImageProcessingTestMixin:
) )
def test_call_numpy_4_channels(self): def test_call_numpy_4_channels(self):
for image_processing_class in self.image_processor_list:
# Test that can process images which have an arbitrary number of channels # Test that can process images which have an arbitrary number of channels
# Initialize image_processing # Initialize image_processing
image_processor = self.image_processing_class(**self.image_processor_dict) image_processor = image_processing_class(**self.image_processor_dict)
# create random numpy tensors # create random numpy tensors
self.image_processor_tester.num_channels = 4 self.image_processor_tester.num_channels = 4
...@@ -291,7 +379,8 @@ class ImageProcessingTestMixin: ...@@ -291,7 +379,8 @@ class ImageProcessingTestMixin:
) )
def test_image_processor_preprocess_arguments(self): def test_image_processor_preprocess_arguments(self):
image_processor = self.image_processing_class(**self.image_processor_dict) for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"): if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"):
preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args
preprocess_parameter_names.remove("self") preprocess_parameter_names.remove("self")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment