Commit 5b433d83 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by Francisco Massa
Browse files

Increase stability of random_horizontal_flip and random_vertical_flip tests (#283)

* increase num_samples, reduce alpha and fix seed for hflip and vflip tests

* flake8 fixes
parent 618072bf
...@@ -396,31 +396,39 @@ class Tester(unittest.TestCase): ...@@ -396,31 +396,39 @@ class Tester(unittest.TestCase):
@unittest.skipIf(stats is None, 'scipy.stats not available') @unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self): def test_random_vertical_flip(self):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM) vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
num_samples = 250
num_vertical = 0 num_vertical = 0
for _ in range(100): for _ in range(num_samples):
out = transforms.RandomVerticalFlip()(img) out = transforms.RandomVerticalFlip()(img)
if out == vimg: if out == vimg:
num_vertical += 1 num_vertical += 1
p_value = stats.binom_test(num_vertical, 100, p=0.5) p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
assert p_value > 0.05 random.setstate(random_state)
assert p_value > 0.0001
@unittest.skipIf(stats is None, 'scipy.stats not available') @unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip(self): def test_random_horizontal_flip(self):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT) himg = img.transpose(Image.FLIP_LEFT_RIGHT)
num_samples = 250
num_horizontal = 0 num_horizontal = 0
for _ in range(100): for _ in range(num_samples):
out = transforms.RandomHorizontalFlip()(img) out = transforms.RandomHorizontalFlip()(img)
if out == himg: if out == himg:
num_horizontal += 1 num_horizontal += 1
p_value = stats.binom_test(num_horizontal, 100, p=0.5) p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
assert p_value > 0.05 random.setstate(random_state)
assert p_value > 0.0001
def test_adjust_brightness(self): def test_adjust_brightness(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment