Unverified Commit b621e38e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove p-value checks in test_transforms.py (#4756)

* Change test_random_apply

* Change test_random_choice

* Change test_randomness

* took care of RandomVert/HorizFlip

* take care of RandomGrayScale

* minor cleanup

* avoid 0 degree rotation just in case
parent bbfda424
import math
import os
import random
from functools import partial
import numpy as np
import pytest
......@@ -541,9 +542,8 @@ class TestPad:
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize(
"fn, trans, config",
"fn, trans, kwargs",
[
(F.invert, transforms.RandomInvert, {}),
(F.posterize, transforms.RandomPosterize, {"bits": 4}),
......@@ -551,28 +551,26 @@ class TestPad:
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
(F.autocontrast, transforms.RandomAutocontrast, {}),
(F.equalize, transforms.RandomEqualize, {}),
(F.vflip, transforms.RandomVerticalFlip, {}),
(F.hflip, transforms.RandomHorizontalFlip, {}),
(partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}),
],
)
@pytest.mark.parametrize("p", (0.5, 0.7))
def test_randomness(fn, trans, config, p):
random_state = random.getstate()
random.seed(42)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_randomness(fn, trans, kwargs, seed, p):
torch.manual_seed(seed)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
inv_img = fn(img, **config)
expected_transformed_img = fn(img, **kwargs)
randomly_transformed_img = trans(p=p, **kwargs)(img)
num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1
if p == 0:
assert randomly_transformed_img == img
elif p == 1:
assert randomly_transformed_img == expected_transformed_img
p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
assert p_value > 0.0001
trans(**kwargs).__repr__()
class TestToPil:
......@@ -1362,160 +1360,42 @@ def test_to_grayscale():
trans4.__repr__()
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_grayscale():
"""Unit tests for random grayscale transform"""
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_random_apply(p, seed):
torch.manual_seed(seed)
random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p)
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
out = random_apply_transform(img)
if p == 0:
assert out == img
elif p == 1:
assert out != img
# Test Set 1: RGB -> 3 channel grayscale
np_rng = np.random.RandomState(0)
random_state = random.getstate()
random.seed(42)
x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)
num_samples = 250
num_gray = 0
for _ in range(num_samples):
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
gray_np_2 = np.array(gray_pil_2)
if (
np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
and np.array_equal(gray_np, gray_np_2[:, :, 0])
):
num_gray = num_gray + 1
p_value = stats.binom_test(num_gray, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
# Test Set 2: grayscale -> 1 channel grayscale
random_state = random.getstate()
random.seed(42)
x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)
num_samples = 250
num_gray = 0
for _ in range(num_samples):
gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
if np.array_equal(gray_np, gray_np_3):
num_gray = num_gray + 1
p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged
random.setstate(random_state)
assert p_value > 0.0001
# Test set 3: Explicit tests
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)
# Case 3a: RGB -> 3 channel grayscale (grayscaled)
trans2 = transforms.RandomGrayscale(p=1.0)
gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
assert_equal(gray_np, gray_np_2[:, :, 0])
# Case 3b: RGB -> 3 channel grayscale (unchanged)
trans2 = transforms.RandomGrayscale(p=0.0)
gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(x_np, gray_np_2)
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
trans3 = transforms.RandomGrayscale(p=1.0)
gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3)
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
trans3 = transforms.RandomGrayscale(p=0.0)
gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3)
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("proba_passthrough", (0, 1))
def test_random_choice(proba_passthrough, seed):
random.seed(seed) # RandomChoice relies on python builtin random.choice, not pytorch
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_apply():
random_state = random.getstate()
random.seed(42)
random_apply_transform = transforms.RandomApply(
random_choice_transform = transforms.RandomChoice(
[
transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
lambda x: x, # passthrough
transforms.RandomRotation((1, 45)),
],
p=0.75,
p=[proba_passthrough, 1 - proba_passthrough],
)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250
num_applies = 0
for _ in range(num_samples):
out = random_apply_transform(img)
if out != img:
num_applies += 1
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_choice():
random_state = random.getstate()
random.seed(42)
random_choice_transform = transforms.RandomChoice(
[transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_resize_15 = 0
num_resize_20 = 0
num_crop_10 = 0
for _ in range(num_samples):
out = random_choice_transform(img)
if out.size == (15, 15):
num_resize_15 += 1
elif out.size == (20, 20):
num_resize_20 += 1
elif out.size == (10, 10):
num_crop_10 += 1
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
assert p_value > 0.0001
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
out = random_choice_transform(img)
if proba_passthrough == 1:
assert out == img
elif proba_passthrough == 0:
assert out != img
random.setstate(random_state)
# Checking if RandomChoice can be printed as string
random_choice_transform.__repr__()
......@@ -1888,6 +1768,7 @@ def test_random_erasing():
tol = 0.05
assert 1 / 3 - tol <= aspect_ratio <= 3 + tol
# Make sure that h > w and h < w are equaly likely (log-scale sampling)
aspect_ratios = []
random.seed(42)
trial = 1000
......@@ -2011,72 +1892,6 @@ def test_randomperspective_fill(mode):
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_vertical_flip():
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip()(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip(p=0.7)(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_horizontal_flip():
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip()(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip(p=0.7)(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_normalize():
def samples_from_standard_normal(tensor):
......
......@@ -562,7 +562,7 @@ class RandomChoice(RandomTransforms):
def __init__(self, transforms, p=None):
super().__init__(transforms)
if p is not None and not isinstance(p, Sequence):
raise TypeError("Argument transforms should be a sequence")
raise TypeError("Argument p should be a sequence")
self.p = p
def __call__(self, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment