Unverified Commit 50468715 authored by Rajat Jaiswal's avatar Rajat Jaiswal Committed by GitHub
Browse files

port the rest of test_transforms.py to pytest (#4026)

parent 5d614bd1
...@@ -219,8 +219,9 @@ def freeze_rng_state(): ...@@ -219,8 +219,9 @@ def freeze_rng_state():
def cycle_over(objs): def cycle_over(objs):
for idx, obj in enumerate(objs): for idx, obj1 in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:] for obj2 in objs[:idx] + objs[idx + 1:]:
yield obj1, obj2
def int_dtypes(): def int_dtypes():
......
import itertools
import os import os
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2 from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math import math
import random import random
import numpy as np import numpy as np
...@@ -30,13 +27,10 @@ GRACE_HOPPER = get_file_path_2( ...@@ -30,13 +27,10 @@ GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class TestConvertImageDtype:
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes()))
def test_convert_image_dtype_float_to_float(self): def test_float_to_float(self, input_dtype, output_dtype):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype) transform_script = torch.jit.script(F.convert_image_dtype)
...@@ -48,21 +42,20 @@ class Tester(unittest.TestCase): ...@@ -48,21 +42,20 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0 desired_min, desired_max = 0.0, 1.0
self.assertAlmostEqual(actual_min, desired_min) assert abs(actual_min - desired_min) < 1e-7
self.assertAlmostEqual(actual_max, desired_max) assert abs(actual_max - desired_max) < 1e-7
def test_convert_image_dtype_float_to_int(self): @pytest.mark.parametrize('input_dtype', float_dtypes())
for input_dtype in float_dtypes(): @pytest.mark.parametrize('output_dtype', int_dtypes())
def test_float_to_int(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype) transform_script = torch.jit.script(F.convert_image_dtype)
if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64 input_dtype == torch.float64 and output_dtype == torch.int64
): ):
with self.assertRaises(RuntimeError): with pytest.raises(RuntimeError):
transform(input_image) transform(input_image)
else: else:
output_image = transform(input_image) output_image = transform(input_image)
...@@ -73,14 +66,13 @@ class Tester(unittest.TestCase): ...@@ -73,14 +66,13 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max desired_min, desired_max = 0, torch.iinfo(output_dtype).max
self.assertEqual(actual_min, desired_min) assert actual_min == desired_min
self.assertEqual(actual_max, desired_max) assert actual_max == desired_max
def test_convert_image_dtype_int_to_float(self): @pytest.mark.parametrize('input_dtype', int_dtypes())
for input_dtype in int_dtypes(): @pytest.mark.parametrize('output_dtype', float_dtypes())
def test_int_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype) transform_script = torch.jit.script(F.convert_image_dtype)
...@@ -92,19 +84,17 @@ class Tester(unittest.TestCase): ...@@ -92,19 +84,17 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0 desired_min, desired_max = 0.0, 1.0
self.assertAlmostEqual(actual_min, desired_min) assert abs(actual_min - desired_min) < 1e-7
self.assertGreaterEqual(actual_min, desired_min) assert actual_min >= desired_min
self.assertAlmostEqual(actual_max, desired_max) assert abs(actual_max - desired_max) < 1e-7
self.assertLessEqual(actual_max, desired_max) assert actual_max <= desired_max
def test_convert_image_dtype_int_to_int(self): @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
for input_dtype, output_dtypes in cycle_over(int_dtypes()): def test_dtype_int_to_int(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype) input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max output_max = torch.iinfo(output_dtype).max
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype) transform_script = torch.jit.script(F.convert_image_dtype)
...@@ -128,19 +118,18 @@ class Tester(unittest.TestCase): ...@@ -128,19 +118,18 @@ class Tester(unittest.TestCase):
else: else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)
self.assertEqual(actual_min, desired_min) assert actual_min == desired_min
self.assertEqual(actual_max, desired_max + error_term) assert actual_max == (desired_max + error_term)
def test_convert_image_dtype_int_to_int_consistency(self): @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
for input_dtype, output_dtypes in cycle_over(int_dtypes()): def test_int_to_int_consistency(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype) input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max output_max = torch.iinfo(output_dtype).max
if output_max <= input_max: if output_max <= input_max:
continue return
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype) inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image)) output_image = inverse_transfrom(transform(input_image))
...@@ -148,8 +137,8 @@ class Tester(unittest.TestCase): ...@@ -148,8 +137,8 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max desired_min, desired_max = 0, input_max
self.assertEqual(actual_min, desired_min) assert actual_min == desired_min
self.assertEqual(actual_max, desired_max) assert actual_max == desired_max
@pytest.mark.skipif(accimage is None, reason="accimage not available") @pytest.mark.skipif(accimage is None, reason="accimage not available")
...@@ -2120,4 +2109,4 @@ def test_random_affine(): ...@@ -2120,4 +2109,4 @@ def test_random_affine():
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment