Unverified Commit ec4e421b authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Update expected slices for pillow > 9 (#16117)



* Update expected slices for pillow > 9

* Add expected slices depending on pillow version

* Add different slices depending on pillow version for other models
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 12d1f077
...@@ -19,6 +19,7 @@ import inspect ...@@ -19,6 +19,7 @@ import inspect
import unittest import unittest
from datasets import load_dataset from datasets import load_dataset
from packaging import version
from transformers import BeitConfig from transformers import BeitConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available from transformers.file_utils import cached_property, is_torch_available, is_vision_available
...@@ -44,6 +45,7 @@ if is_torch_available(): ...@@ -44,6 +45,7 @@ if is_torch_available():
if is_vision_available(): if is_vision_available():
import PIL
from PIL import Image from PIL import Image
from transformers import BeitFeatureExtractor from transformers import BeitFeatureExtractor
...@@ -536,12 +538,25 @@ class BeitModelIntegrationTest(unittest.TestCase): ...@@ -536,12 +538,25 @@ class BeitModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 150, 160, 160)) expected_shape = torch.Size((1, 150, 160, 160))
self.assertEqual(logits.shape, expected_shape) self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor( is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]], if is_pillow_less_than_9:
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]], expected_slice = torch.tensor(
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]], [
] [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
).to(torch_device) [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
],
device=torch_device,
)
else:
expected_slice = torch.tensor(
[
[[-4.8960, -2.3688, -3.0355], [-2.8478, -0.9836, -1.7418], [-2.9449, -1.3332, -2.1456]],
[[-5.8081, -3.4124, -4.1006], [-3.8561, -2.2081, -3.0323], [-3.8365, -2.4601, -3.3669]],
[[-0.0309, 3.9868, 4.0540], [2.9640, 4.6877, 4.9976], [3.2081, 4.7690, 4.9942]],
],
device=torch_device,
)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from datasets import load_dataset from datasets import load_dataset
from packaging import version
from transformers import ViltConfig, is_torch_available, is_vision_available from transformers import ViltConfig, is_torch_available, is_vision_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
...@@ -41,6 +42,7 @@ if is_torch_available(): ...@@ -41,6 +42,7 @@ if is_torch_available():
from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
import PIL
from PIL import Image from PIL import Image
from transformers import ViltProcessor from transformers import ViltProcessor
...@@ -603,5 +605,17 @@ class ViltModelIntegrationTest(unittest.TestCase): ...@@ -603,5 +605,17 @@ class ViltModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size([1, 2]) expected_shape = torch.Size([1, 2])
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-2.4013, 2.9342]).to(torch_device) is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
if is_pillow_less_than_9:
expected_slice = torch.tensor(
[-2.4013, 2.9342],
device=torch_device,
)
else:
expected_slice = torch.tensor(
[-2.3713, 2.9168],
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
...@@ -18,6 +18,7 @@ import tempfile ...@@ -18,6 +18,7 @@ import tempfile
import unittest import unittest
from datasets import load_dataset from datasets import load_dataset
from packaging import version
from transformers.file_utils import cached_property, is_torch_available, is_vision_available from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, slow, torch_device
...@@ -51,6 +52,7 @@ if is_torch_available(): ...@@ -51,6 +52,7 @@ if is_torch_available():
if is_vision_available(): if is_vision_available():
import PIL
from PIL import Image from PIL import Image
from transformers import TrOCRProcessor, ViTFeatureExtractor from transformers import TrOCRProcessor, ViTFeatureExtractor
...@@ -687,9 +689,18 @@ class TrOCRModelIntegrationTest(unittest.TestCase): ...@@ -687,9 +689,18 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size)) expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor( is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210]
).to(torch_device) if is_pillow_less_than_9:
expected_slice = torch.tensor(
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210],
device=torch_device,
)
else:
expected_slice = torch.tensor(
[-5.6844, -5.8372, 1.1518, -6.8984, 6.8587, -2.4453, 1.2347, -1.0241, -1.9649, -3.9109],
device=torch_device,
)
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment