Unverified Commit 75f5b57e authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[BC-breaking] RandomErasing is now scriptable (#2386)

* Related to #2292
- RandomErasing is not scriptable

* Fixed code according to review comments
- added additional checking of value vs img num_channels
parent e757d521
...@@ -1618,38 +1618,64 @@ class Tester(unittest.TestCase): ...@@ -1618,38 +1618,64 @@ class Tester(unittest.TestCase):
def test_random_erasing(self): def test_random_erasing(self):
"""Unit tests for random erasing transform""" """Unit tests for random erasing transform"""
for is_scripted in [False, True]:
img = torch.rand([3, 60, 60]) torch.manual_seed(12)
img = torch.rand(3, 60, 60)
# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0.2) # Test Set 0: invalid value
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value) random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
img_output = F.erase(img, i, j, h, w, v) with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
self.assertEqual(img_output.size(0), 3) img_re = random_erasing(img)
# Test Set 2: Check if the unerased region is preserved # Test Set 1: Erasing with int value
orig_unerased = img.clone() random_erasing = transforms.RandomErasing(value=0.2)
orig_unerased[:, i:i + h, j:j + w] = 0 if is_scripted:
output_unerased = img_output.clone() random_erasing = torch.jit.script(random_erasing)
output_unerased[:, i:i + h, j:j + w] = 0
self.assertTrue(torch.equal(orig_unerased, output_unerased)) i, j, h, w, v = transforms.RandomErasing.get_params(
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
# Test Set 3: Erasing with random value )
img_re = transforms.RandomErasing(value='random')(img) img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_re.size(0), 3) self.assertEqual(img_output.size(0), 3)
# Test Set 4: Erasing with tuple value # Test Set 2: Check if the unerased region is preserved
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img) true_output = img.clone()
self.assertEqual(img_re.size(0), 3) true_output[:, i:i + h, j:j + w] = random_erasing.value
self.assertTrue(torch.equal(true_output, img_output))
# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img) # Test Set 3: Erasing with random value
self.assertTrue(torch.equal(img_re, img)) random_erasing = transforms.RandomErasing(value="random")
if is_scripted:
# Test Set 6: Checking when no erased region is selected random_erasing = torch.jit.script(random_erasing)
img = torch.rand([3, 300, 1]) img_re = random_erasing(img)
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
self.assertTrue(torch.equal(img_re, img)) self.assertEqual(img_re.size(0), 3)
# Test Set 4: Erasing with tuple value
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
self.assertTrue(torch.equal(true_output, img_output))
# Test Set 5: Testing the inplace behaviour
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))
# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -950,7 +950,7 @@ def to_grayscale(img, num_output_channels=1): ...@@ -950,7 +950,7 @@ def to_grayscale(img, num_output_channels=1):
return img return img
def erase(img, i, j, h, w, v, inplace=False): def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
""" Erase the input Tensor Image with given value. """ Erase the input Tensor Image with given value.
Args: Args:
......
...@@ -3,7 +3,7 @@ import numbers ...@@ -3,7 +3,7 @@ import numbers
import random import random
import warnings import warnings
from collections.abc import Sequence, Iterable from collections.abc import Sequence, Iterable
from typing import Tuple from typing import Tuple, List, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -1343,7 +1343,7 @@ class RandomGrayscale(object): ...@@ -1343,7 +1343,7 @@ class RandomGrayscale(object):
return self.__class__.__name__ + '(p={0})'.format(self.p) return self.__class__.__name__ + '(p={0})'.format(self.p)
class RandomErasing(object): class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an image and erases its pixels. """ Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
...@@ -1370,13 +1370,21 @@ class RandomErasing(object): ...@@ -1370,13 +1370,21 @@ class RandomErasing(object):
""" """
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list)) super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)") warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1: if scale[0] < 0 or scale[1] > 1:
raise ValueError("range of scale should be between 0 and 1") raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1: if p < 0 or p > 1:
raise ValueError("range of random erasing probability should be between 0 and 1") raise ValueError("Random erasing probability should be between 0 and 1")
self.p = p self.p = p
self.scale = scale self.scale = scale
...@@ -1385,13 +1393,18 @@ class RandomErasing(object): ...@@ -1385,13 +1393,18 @@ class RandomErasing(object):
self.inplace = inplace self.inplace = inplace
@staticmethod @staticmethod
def get_params(img, scale, ratio, value=0): def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
) -> Tuple[int, int, int, int, Tensor]:
"""Get parameters for ``erase`` for a random erasing. """Get parameters for ``erase`` for a random erasing.
Args: Args:
img (Tensor): Tensor image of size (C, H, W) to be erased. img (Tensor): Tensor image of size (C, H, W) to be erased.
scale: range of proportion of erased area against input image. scale (tuple or list): range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area. ratio (tuple or list): range of aspect ratio of erased area.
value (list, optional): erasing value. If None, it is interpreted as "random"
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
i.e. ``value[0]``.
Returns: Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
...@@ -1400,27 +1413,27 @@ class RandomErasing(object): ...@@ -1400,27 +1413,27 @@ class RandomErasing(object):
area = img_h * img_w area = img_h * img_w
for _ in range(10): for _ in range(10):
erase_area = random.uniform(scale[0], scale[1]) * area erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = random.uniform(ratio[0], ratio[1]) aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()
h = int(round(math.sqrt(erase_area * aspect_ratio))) h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
if h < img_h and w < img_w: i = torch.randint(0, img_h - h, size=(1, )).item()
i = random.randint(0, img_h - h) j = torch.randint(0, img_w - w, size=(1, )).item()
j = random.randint(0, img_w - w) return i, j, h, w, v
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v
# Return original image # Return original image
return 0, 0, img_h, img_w, img return 0, 0, img_h, img_w, img
def __call__(self, img): def forward(self, img):
""" """
Args: Args:
img (Tensor): Tensor image of size (C, H, W) to be erased. img (Tensor): Tensor image of size (C, H, W) to be erased.
...@@ -1428,7 +1441,24 @@ class RandomErasing(object): ...@@ -1428,7 +1441,24 @@ class RandomErasing(object):
Returns: Returns:
img (Tensor): Erased Tensor image. img (Tensor): Erased Tensor image.
""" """
if random.uniform(0, 1) < self.p: if torch.rand(1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value, ]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img.shape[-3])):
raise ValueError(
"If value is a sequence, it should have either a single value or "
"{} (number of input channels)".format(img.shape[-3])
)
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace) return F.erase(img, x, y, h, w, v, self.inplace)
return img return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment