Commit 59858699 authored by vfdev's avatar vfdev Committed by Francisco Massa
Browse files

Probability parameter in RandomHorizontalFlip, RandomHorizontalFlip (#417)

* Set probability as configuration parameter in RandomHorizontalFlip and RandomHorizontalFlip (#414)

* Fix documentation
parent 22385bc6
...@@ -475,6 +475,17 @@ class Tester(unittest.TestCase): ...@@ -475,6 +475,17 @@ class Tester(unittest.TestCase):
random.setstate(random_state) random.setstate(random_state)
assert p_value > 0.0001 assert p_value > 0.0001
num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip(p=0.7)(img)
if out == vimg:
num_vertical += 1
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomVerticalFlip can be printed as string # Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__() transforms.RandomVerticalFlip().__repr__()
...@@ -496,6 +507,17 @@ class Tester(unittest.TestCase): ...@@ -496,6 +507,17 @@ class Tester(unittest.TestCase):
random.setstate(random_state) random.setstate(random_state)
assert p_value > 0.0001 assert p_value > 0.0001
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip(p=0.7)(img)
if out == himg:
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomHorizontalFlip can be printed as string # Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__() transforms.RandomHorizontalFlip().__repr__()
......
...@@ -321,7 +321,14 @@ class RandomCrop(object): ...@@ -321,7 +321,14 @@ class RandomCrop(object):
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL Image randomly with a probability of 0.5.""" """Horizontally flip the given PIL Image randomly with a given probability.
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, img): def __call__(self, img):
""" """
...@@ -331,16 +338,23 @@ class RandomHorizontalFlip(object): ...@@ -331,16 +338,23 @@ class RandomHorizontalFlip(object):
Returns: Returns:
PIL Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < self.p:
return F.hflip(img) return F.hflip(img)
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomVerticalFlip(object): class RandomVerticalFlip(object):
"""Vertically flip the given PIL Image randomly with a probability of 0.5.""" """Vertically flip the given PIL Image randomly with a given probability.
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, img): def __call__(self, img):
""" """
...@@ -350,12 +364,12 @@ class RandomVerticalFlip(object): ...@@ -350,12 +364,12 @@ class RandomVerticalFlip(object):
Returns: Returns:
PIL Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < self.p:
return F.vflip(img) return F.vflip(img)
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomResizedCrop(object): class RandomResizedCrop(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment