Unverified Commit 21426ddc authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port some tests in test_transforms.py to pytest (#3964)

parent a0b44d70
......@@ -403,94 +403,6 @@ class Tester(unittest.TestCase):
with self.assertRaisesRegex(ValueError, r"Required crop size .+ is larger then input image size .+"):
t(img)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_apply(self):
random_state = random.getstate()
random.seed(42)
random_apply_transform = transforms.RandomApply(
[
transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
], p=0.75
)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250
num_applies = 0
for _ in range(num_samples):
out = random_apply_transform(img)
if out != img:
num_applies += 1
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_choice(self):
random_state = random.getstate()
random.seed(42)
random_choice_transform = transforms.RandomChoice(
[
transforms.Resize(15),
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_resize_15 = 0
num_resize_20 = 0
num_crop_10 = 0
for _ in range(num_samples):
out = random_choice_transform(img)
if out.size == (15, 15):
num_resize_15 += 1
elif out.size == (20, 20):
num_resize_20 += 1
elif out.size == (10, 10):
num_crop_10 += 1
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
self.assertGreater(p_value, 0.0001)
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
self.assertGreater(p_value, 0.0001)
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
self.assertGreater(p_value, 0.0001)
random.setstate(random_state)
# Checking if RandomChoice can be printed as string
random_choice_transform.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_order(self):
random_state = random.getstate()
random.seed(42)
random_order_transform = transforms.RandomOrder(
[
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_normal_order = 0
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
for _ in range(num_samples):
out = random_order_transform(img)
if out == resize_crop_out:
num_normal_order += 1
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
# Checking if RandomOrder can be printed as string
random_order_transform.__repr__()
def test_to_tensor(self):
test_channels = [1, 3, 4]
height, width = 4, 4
......@@ -1994,5 +1906,96 @@ def test_random_grayscale():
trans3.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_apply():
random_state = random.getstate()
random.seed(42)
random_apply_transform = transforms.RandomApply(
[
transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
], p=0.75
)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250
num_applies = 0
for _ in range(num_samples):
out = random_apply_transform(img)
if out != img:
num_applies += 1
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_choice():
random_state = random.getstate()
random.seed(42)
random_choice_transform = transforms.RandomChoice(
[
transforms.Resize(15),
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_resize_15 = 0
num_resize_20 = 0
num_crop_10 = 0
for _ in range(num_samples):
out = random_choice_transform(img)
if out.size == (15, 15):
num_resize_15 += 1
elif out.size == (20, 20):
num_resize_20 += 1
elif out.size == (10, 10):
num_crop_10 += 1
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
assert p_value > 0.0001
random.setstate(random_state)
# Checking if RandomChoice can be printed as string
random_choice_transform.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_order():
random_state = random.getstate()
random.seed(42)
random_order_transform = transforms.RandomOrder(
[
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_normal_order = 0
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
for _ in range(num_samples):
out = random_order_transform(img)
if out == resize_crop_out:
num_normal_order += 1
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomOrder can be printed as string
random_order_transform.__repr__()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment