Commit 3254560b authored by Zhun Zhong's avatar Zhun Zhong Committed by Francisco Massa
Browse files

transforms: add Random Erasing for image augmentation (#909)

* add erase function

* add Random Erasing

* Update transforms.py

* Update transforms.py

* add test for random erasing

* Update test_transforms.py

* fix flake8

* Update test_transforms.py

* Update functional.py

* Update test_transforms.py

* fix bug for per-pixel erasing

* Update transforms.py

* specific for coordinate (x, y)

* add raise TypeError for img

* Update transforms.py

* Update transforms.rst
parent d38600fe
...@@ -63,6 +63,8 @@ Transforms on torch.\*Tensor ...@@ -63,6 +63,8 @@ Transforms on torch.\*Tensor
:members: __call__ :members: __call__
:special-members: :special-members:
.. autoclass:: RandomErasing
Conversion Transforms Conversion Transforms
--------------------- ---------------------
......
...@@ -1342,6 +1342,23 @@ class Tester(unittest.TestCase): ...@@ -1342,6 +1342,23 @@ class Tester(unittest.TestCase):
# Checking if RandomGrayscale can be printed as string # Checking if RandomGrayscale can be printed as string
trans3.__repr__() trans3.__repr__()
def test_random_erasing(self):
"""Unit tests for random erasing transform"""
img = torch.rand([3, 224, 224])
# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0)(img)
assert img_re.size(0) == 3
# Test Set 2: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img)
assert img_re.size(0) == 3
# Test Set 3: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
assert img_re.size(0) == 3
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -804,3 +804,24 @@ def to_grayscale(img, num_output_channels=1): ...@@ -804,3 +804,24 @@ def to_grayscale(img, num_output_channels=1):
raise ValueError('num_output_channels should be either 1 or 3') raise ValueError('num_output_channels should be either 1 or 3')
return img return img
def erase(img, i, j, h, w, v):
""" Erase the input Tensor Image with given value.
Args:
img (Tensor Image): Tensor image of size (C, H, W) to be erased
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
Returns:
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
img[:, i:i + h, j:j + w] = v
return img
...@@ -28,7 +28,7 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", ...@@ -28,7 +28,7 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective"] "RandomPerspective", "RandomErasing"]
_pil_interpolation_to_str = { _pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST', Image.NEAREST: 'PIL.Image.NEAREST',
...@@ -1182,3 +1182,84 @@ class RandomGrayscale(object): ...@@ -1182,3 +1182,84 @@ class RandomGrayscale(object):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={0})'.format(self.p) return self.__class__.__name__ + '(p={0})'.format(self.p)
class RandomErasing(object):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
Args:
p: probability that the random erasing operation will be performed.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
value: erasing value. Default is 0. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
Returns:
Erased Image.
# Examples:
>>> transform = transforms.Compose([
>>> transforms.RandomHorizontalFlip(),
>>> transforms.ToTensor(),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> transforms.RandomErasing(),
>>> ])
"""
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0):
assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("range of scale should be between 0 and 1")
self.p = p
self.scale = scale
self.ratio = ratio
self.value = value
@staticmethod
def get_params(img, scale, ratio, value=0):
"""Get parameters for ``erase`` for a random erasing.
Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
area = img.shape[1] * img.shape[2]
while True:
erase_area = random.uniform(scale[0], scale[1]) * area
aspect_ratio = random.uniform(ratio[0], ratio[1])
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if h < img.shape[1] and w < img.shape[2]:
i = random.randint(0, img.shape[1] - h)
j = random.randint(0, img.shape[2] - w)
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.rand(img.size()[0], h, w)
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v
def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.
Returns:
img (Tensor): Erased Tensor image.
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
return F.erase(img, x, y, h, w, v)
return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment