Unverified Commit 50468715 authored by Rajat Jaiswal's avatar Rajat Jaiswal Committed by GitHub
Browse files

port the rest of test_transforms.py to pytest (#4026)

parent 5d614bd1
......@@ -219,8 +219,9 @@ def freeze_rng_state():
def cycle_over(objs):
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]
for idx, obj1 in enumerate(objs):
for obj2 in objs[:idx] + objs[idx + 1:]:
yield obj1, obj2
def int_dtypes():
......
import itertools
import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math
import random
import numpy as np
......@@ -30,13 +27,10 @@ GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase):
def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
class TestConvertImageDtype:
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes()))
def test_float_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)
......@@ -48,21 +42,20 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
self.assertAlmostEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
assert abs(actual_min - desired_min) < 1e-7
assert abs(actual_max - desired_max) < 1e-7
def test_convert_image_dtype_float_to_int(self):
for input_dtype in float_dtypes():
@pytest.mark.parametrize('input_dtype', float_dtypes())
@pytest.mark.parametrize('output_dtype', int_dtypes())
def test_float_to_int(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)
if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)
......@@ -73,14 +66,13 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max
self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)
assert actual_min == desired_min
assert actual_max == desired_max
def test_convert_image_dtype_int_to_float(self):
for input_dtype in int_dtypes():
@pytest.mark.parametrize('input_dtype', int_dtypes())
@pytest.mark.parametrize('output_dtype', float_dtypes())
def test_int_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)
......@@ -92,19 +84,17 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
self.assertAlmostEqual(actual_min, desired_min)
self.assertGreaterEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
self.assertLessEqual(actual_max, desired_max)
assert abs(actual_min - desired_min) < 1e-7
assert actual_min >= desired_min
assert abs(actual_max - desired_max) < 1e-7
assert actual_max <= desired_max
def test_convert_image_dtype_int_to_int(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
def test_dtype_int_to_int(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)
......@@ -128,19 +118,18 @@ class Tester(unittest.TestCase):
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)
self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max + error_term)
assert actual_min == desired_min
assert actual_max == (desired_max + error_term)
def test_convert_image_dtype_int_to_int_consistency(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
def test_int_to_int_consistency(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
continue
return
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))
......@@ -148,8 +137,8 @@ class Tester(unittest.TestCase):
actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max
self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)
assert actual_min == desired_min
assert actual_max == desired_max
@pytest.mark.skipif(accimage is None, reason="accimage not available")
......@@ -2120,4 +2109,4 @@ def test_random_affine():
if __name__ == '__main__':
unittest.main()
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment