Unverified Commit 75f5b57e authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[BC-breaking] RandomErasing is now scriptable (#2386)

* Related to #2292
- RandomErasing is not scriptable

* Fixed code according to review comments
- added additional checking of value vs img num_channels
parent e757d521
......@@ -1618,38 +1618,64 @@ class Tester(unittest.TestCase):
def test_random_erasing(self):
"""Unit tests for random erasing transform"""
img = torch.rand([3, 60, 60])
# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0.2)
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)
# Test Set 2: Check if the unerased region is preserved
orig_unerased = img.clone()
orig_unerased[:, i:i + h, j:j + w] = 0
output_unerased = img_output.clone()
output_unerased[:, i:i + h, j:j + w] = 0
self.assertTrue(torch.equal(orig_unerased, output_unerased))
# Test Set 3: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img)
self.assertEqual(img_re.size(0), 3)
# Test Set 4: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
self.assertEqual(img_re.size(0), 3)
# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
self.assertTrue(torch.equal(img_re, img))
# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
self.assertTrue(torch.equal(img_re, img))
for is_scripted in [False, True]:
torch.manual_seed(12)
img = torch.rand(3, 60, 60)
# Test Set 0: invalid value
random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
img_re = random_erasing(img)
# Test Set 1: Erasing with int value
random_erasing = transforms.RandomErasing(value=0.2)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
i, j, h, w, v = transforms.RandomErasing.get_params(
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)
# Test Set 2: Check if the unerased region is preserved
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = random_erasing.value
self.assertTrue(torch.equal(true_output, img_output))
# Test Set 3: Erasing with random value
random_erasing = transforms.RandomErasing(value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
# Test Set 4: Erasing with tuple value
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
self.assertTrue(torch.equal(true_output, img_output))
# Test Set 5: Testing the inplace behaviour
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))
# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))
if __name__ == '__main__':
......
......@@ -950,7 +950,7 @@ def to_grayscale(img, num_output_channels=1):
return img
def erase(img, i, j, h, w, v, inplace=False):
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
""" Erase the input Tensor Image with given value.
Args:
......
......@@ -3,7 +3,7 @@ import numbers
import random
import warnings
from collections.abc import Sequence, Iterable
from typing import Tuple
from typing import Tuple, List, Optional
import numpy as np
import torch
......@@ -1343,7 +1343,7 @@ class RandomGrayscale(object):
return self.__class__.__name__ + '(p={0})'.format(self.p)
class RandomErasing(object):
class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
......@@ -1370,13 +1370,21 @@ class RandomErasing(object):
"""
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("range of scale should be between 0 and 1")
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("range of random erasing probability should be between 0 and 1")
raise ValueError("Random erasing probability should be between 0 and 1")
self.p = p
self.scale = scale
......@@ -1385,13 +1393,18 @@ class RandomErasing(object):
self.inplace = inplace
@staticmethod
def get_params(img, scale, ratio, value=0):
def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
) -> Tuple[int, int, int, int, Tensor]:
"""Get parameters for ``erase`` for a random erasing.
Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
scale (tuple or list): range of proportion of erased area against input image.
ratio (tuple or list): range of aspect ratio of erased area.
value (list, optional): erasing value. If None, it is interpreted as "random"
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
i.e. ``value[0]``.
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
......@@ -1400,27 +1413,27 @@ class RandomErasing(object):
area = img_h * img_w
for _ in range(10):
erase_area = random.uniform(scale[0], scale[1]) * area
aspect_ratio = random.uniform(ratio[0], ratio[1])
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
if h < img_h and w < img_w:
i = random.randint(0, img_h - h)
j = random.randint(0, img_w - w)
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v
i = torch.randint(0, img_h - h, size=(1, )).item()
j = torch.randint(0, img_w - w, size=(1, )).item()
return i, j, h, w, v
# Return original image
return 0, 0, img_h, img_w, img
def __call__(self, img):
def forward(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.
......@@ -1428,7 +1441,24 @@ class RandomErasing(object):
Returns:
img (Tensor): Erased Tensor image.
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
if torch.rand(1) < self.p:
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value, ]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img.shape[-3])):
raise ValueError(
"If value is a sequence, it should have either a single value or "
"{} (number of input channels)".format(img.shape[-3])
)
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace)
return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment