"vscode:/vscode.git/clone" did not exist on "e98233dde138270f4cc511ded5136e41a2677644"
Unverified Commit 6bca43bb authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Input data format (#25464)

* Add copied from statements for image processors

* Move out rescale and normalize to base image processor

* Remove rescale and normalize from vit (post rebase)

* Update docstrings and tidy up

* PR comments

* Add input_data_format as preprocess argument

* Resolve tests and tidy up

* Remove num_channels argument

* Update doc strings -> default ints not in code formatting
parent a6609caf
......@@ -70,7 +70,7 @@ class BlipImageProcessingTester(unittest.TestCase):
}
def expected_output_image_shape(self, images):
return 3, self.size["height"], self.size["width"]
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
......@@ -135,3 +135,11 @@ class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.Tes
@unittest.skip("BlipImageProcessor does not support 4 channels yet") # FIXME Amy
def test_call_pytorch(self):
return super().test_call_torch()
@unittest.skip("BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_pil(self):
pass
@unittest.skip("BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_numpy_4_channels(self):
pass
......@@ -337,6 +337,11 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_numpy(self):
self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
def test_call_numpy_4_channels(self):
self.image_processing_class.num_channels = 4
self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
self.image_processing_class.num_channels = 3
def test_call_pytorch(self):
self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
......
......@@ -198,6 +198,10 @@ class ImageGPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
@unittest.skip("ImageGPT assumes clusters for 3 channels")
def test_call_numpy_4_channels(self):
pass
# Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input
def test_call_pytorch(self):
# Initialize image_processing
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment