Commit 75003739 authored by vfdev's avatar vfdev Committed by Francisco Massa
Browse files

Add RandomApply, RandomChoice, RandomOrder transformations (#402)

* Add RandomApply, RandomChoice, RandomOrder transformations

* Rename argument `proba` to `p`
parent 59858699
...@@ -258,6 +258,91 @@ class Tester(unittest.TestCase): ...@@ -258,6 +258,91 @@ class Tester(unittest.TestCase):
# Checking if Lambda can be printed as string # Checking if Lambda can be printed as string
trans.__repr__() trans.__repr__()
def test_random_apply(self):
random_state = random.getstate()
random.seed(42)
random_apply_transform = transforms.RandomApply(
[
transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
], p=0.75
)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250
num_applies = 0
for _ in range(num_samples):
out = random_apply_transform(img)
if out != img:
num_applies += 1
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()
def test_random_choice(self):
random_state = random.getstate()
random.seed(42)
random_choice_transform = transforms.RandomChoice(
[
transforms.Resize(15),
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_resize_15 = 0
num_resize_20 = 0
num_crop_10 = 0
for _ in range(num_samples):
out = random_choice_transform(img)
if out.size == (15, 15):
num_resize_15 += 1
elif out.size == (20, 20):
num_resize_20 += 1
elif out.size == (10, 10):
num_crop_10 += 1
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
assert p_value > 0.0001
random.setstate(random_state)
# Checking if RandomChoice can be printed as string
random_choice_transform.__repr__()
def test_random_order(self):
random_state = random.getstate()
random.seed(42)
random_order_transform = transforms.RandomOrder(
[
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_normal_order = 0
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
for _ in range(num_samples):
out = random_order_transform(img)
if out == resize_crop_out:
num_normal_order += 1
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomOrder can be printed as string
random_order_transform.__repr__()
def test_to_tensor(self): def test_to_tensor(self):
test_channels = [1, 3, 4] test_channels = [1, 3, 4]
height, width = 4, 4 height, width = 4, 4
......
...@@ -16,9 +16,9 @@ import warnings ...@@ -16,9 +16,9 @@ import warnings
from . import functional as F from . import functional as F
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"Grayscale", "RandomGrayscale"] "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
class Compose(object): class Compose(object):
...@@ -261,6 +261,77 @@ class Lambda(object): ...@@ -261,6 +261,77 @@ class Lambda(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class RandomTransforms(object):
"""Base class for a list of transformations with randomness
Args:
transforms (list or tuple): list of transformations
"""
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
self.transforms = transforms
def __call__(self, *args, **kwargs):
raise NotImplementedError()
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class RandomApply(RandomTransforms):
"""Apply randomly a list of transformations with a given probability
Args:
transforms (list or tuple): list of transformations
p (float): probability
"""
def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__(transforms)
self.p = p
def __call__(self, img):
if self.p < random.random():
return img
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += '\n p={}'.format(self.p)
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class RandomOrder(RandomTransforms):
"""Apply a list of transformations in a random order
"""
def __call__(self, img):
order = list(range(len(self.transforms)))
random.shuffle(order)
for i in order:
img = self.transforms[i](img)
return img
class RandomChoice(RandomTransforms):
"""Apply single transformation randomly picked from a list
"""
def __call__(self, img):
t = random.choice(self.transforms)
return t(img)
class RandomCrop(object): class RandomCrop(object):
"""Crop the given PIL Image at a random location. """Crop the given PIL Image at a random location.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment