Unverified Commit 729fee5c authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port random flip tests to pytest in test_transforms.py (#3971)

parent 15bebfbc
......@@ -734,121 +734,6 @@ class Tester(unittest.TestCase):
with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
transforms.ToPILImage()(np.ones([4, 4, 6]))
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip()(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip(p=0.7)(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip(self):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip()(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip(p=0.7)(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipy.stats is not available')
def test_normalize(self):
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
img = torch.rand(channels, 10, 10)
mean = [img[c].mean() for c in range(channels)]
std = [img[c].std() for c in range(channels)]
normalized = transforms.Normalize(mean, std)(img)
self.assertTrue(samples_from_standard_normal(normalized))
random.setstate(random_state)
# Checking if Normalize can be printed as string
transforms.Normalize(mean, std).__repr__()
# Checking the optional in-place behaviour
tensor = torch.rand((1, 16, 16))
tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
assert_equal(tensor, tensor_inplace)
def test_normalize_different_dtype(self):
for dtype1 in [torch.float32, torch.float64]:
img = torch.rand(3, 10, 10, dtype=dtype1)
for dtype2 in [torch.int64, torch.float32, torch.float64]:
mean = torch.tensor([1, 2, 3], dtype=dtype2)
std = torch.tensor([1, 2, 1], dtype=dtype2)
# checks that it doesn't crash
transforms.functional.normalize(img, mean, std)
def test_normalize_3d_tensor(self):
torch.manual_seed(28)
n_channels = 3
img_size = 10
mean = torch.rand(n_channels)
std = torch.rand(n_channels)
img = torch.rand(n_channels, img_size, img_size)
target = F.normalize(img, mean, std)
mean_unsqueezed = mean.view(-1, 1, 1)
std_unsqueezed = std.view(-1, 1, 1)
result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
result2 = F.normalize(img,
mean_unsqueezed.repeat(1, img_size, img_size),
std_unsqueezed.repeat(1, img_size, img_size))
torch.testing.assert_close(target, result1)
torch.testing.assert_close(target, result2)
def test_color_jitter(self):
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)
......@@ -2025,5 +1910,124 @@ def test_random_order():
random_order_transform.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_vertical_flip():
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip()(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip(p=0.7)(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_horizontal_flip():
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip()(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip(p=0.7)(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_normalize():
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
img = torch.rand(channels, 10, 10)
mean = [img[c].mean() for c in range(channels)]
std = [img[c].std() for c in range(channels)]
normalized = transforms.Normalize(mean, std)(img)
assert samples_from_standard_normal(normalized)
random.setstate(random_state)
# Checking if Normalize can be printed as string
transforms.Normalize(mean, std).__repr__()
# Checking the optional in-place behaviour
tensor = torch.rand((1, 16, 16))
tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
assert_equal(tensor, tensor_inplace)
@pytest.mark.parametrize('dtype1', [torch.float32, torch.float64])
@pytest.mark.parametrize('dtype2', [torch.int64, torch.float32, torch.float64])
def test_normalize_different_dtype(dtype1, dtype2):
img = torch.rand(3, 10, 10, dtype=dtype1)
mean = torch.tensor([1, 2, 3], dtype=dtype2)
std = torch.tensor([1, 2, 1], dtype=dtype2)
# checks that it doesn't crash
transforms.functional.normalize(img, mean, std)
def test_normalize_3d_tensor():
torch.manual_seed(28)
n_channels = 3
img_size = 10
mean = torch.rand(n_channels)
std = torch.rand(n_channels)
img = torch.rand(n_channels, img_size, img_size)
target = F.normalize(img, mean, std)
mean_unsqueezed = mean.view(-1, 1, 1)
std_unsqueezed = std.view(-1, 1, 1)
result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
result2 = F.normalize(img, mean_unsqueezed.repeat(1, img_size, img_size),
std_unsqueezed.repeat(1, img_size, img_size))
torch.testing.assert_close(target, result1)
torch.testing.assert_close(target, result2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment