Commit fbe4ad58 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by GitHub
Browse files

add RandomVerticalFlip transform (#262)

add RandomVerticalFlip transform
parent 459dc59e
...@@ -9,6 +9,11 @@ try: ...@@ -9,6 +9,11 @@ try:
except ImportError: except ImportError:
accimage = None accimage = None
try:
from scipy import stats
except ImportError:
stats = None
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg' GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
...@@ -327,6 +332,34 @@ class Tester(unittest.TestCase): ...@@ -327,6 +332,34 @@ class Tester(unittest.TestCase):
assert img.mode == 'I' assert img.mode == 'I'
assert np.allclose(img, img_data[:, :, 0]) assert np.allclose(img, img_data[:, :, 0])
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
num_vertical = 0
for _ in range(100):
out = transforms.RandomVerticalFlip()(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, 100, p=0.5)
assert p_value > 0.05
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip(self):
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
num_horizontal = 0
for _ in range(100):
out = transforms.RandomHorizontalFlip()(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, 100, p=0.5)
assert p_value > 0.05
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -547,6 +547,22 @@ class RandomHorizontalFlip(object): ...@@ -547,6 +547,22 @@ class RandomHorizontalFlip(object):
return img return img
class RandomVerticalFlip(object):
"""Vertically flip the given PIL.Image randomly with a probability of 0.5"""
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be flipped.
Returns:
PIL.Image: Randomly flipped image.
"""
if random.random() < 0.5:
return img.transpose(Image.FLIP_TOP_BOTTOM)
return img
class RandomSizedCrop(object): class RandomSizedCrop(object):
"""Crop the given PIL.Image to random size and aspect ratio. """Crop the given PIL.Image to random size and aspect ratio.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment