Unverified Commit 2a52c2dc authored by Shrill Shrestha's avatar Shrill Shrestha Committed by GitHub
Browse files

Port test_randomness in test_transforms.py to pytest (#3955)

parent f7b4cb04
...@@ -1621,76 +1621,6 @@ class Tester(unittest.TestCase): ...@@ -1621,76 +1621,6 @@ class Tester(unittest.TestCase):
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
transforms.GaussianBlur(3, "sigma_string") transforms.GaussianBlur(3, "sigma_string")
def _test_randomness(self, fn, trans, configs):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
for p in [0.5, 0.7]:
for config in configs:
inv_img = fn(img, **config)
num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1
p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_invert(self):
self._test_randomness(
F.invert,
transforms.RandomInvert,
[{}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
self._test_randomness(
F.posterize,
transforms.RandomPosterize,
[{"bits": 4}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
self._test_randomness(
F.solarize,
transforms.RandomSolarize,
[{"threshold": 192}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
F.autocontrast,
transforms.RandomAutocontrast,
[{}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
self._test_randomness(
F.equalize,
transforms.RandomEqualize,
[{}]
)
def test_autoaugment(self): def test_autoaugment(self):
for policy in transforms.AutoAugmentPolicy: for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]: for fill in [None, 85, (128, 128, 128)]:
...@@ -1834,6 +1764,36 @@ class TestPad: ...@@ -1834,6 +1764,36 @@ class TestPad:
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False) assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False)
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize('fn, trans, config', [
(F.invert, transforms.RandomInvert, {}),
(F.posterize, transforms.RandomPosterize, {"bits": 4}),
(F.solarize, transforms.RandomSolarize, {"threshold": 192}),
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
(F.autocontrast, transforms.RandomAutocontrast, {}),
(F.equalize, transforms.RandomEqualize, {})])
@pytest.mark.parametrize('p', (.5, .7))
def test_randomness(fn, trans, config, p):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
inv_img = fn(img, **config)
num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1
p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
assert p_value > 0.0001
def test_adjust_brightness(): def test_adjust_brightness():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment