Commit 36158026 authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Making RandomErasing Out of place by default and improvising tests for the same (#1060)

* test improved

* Update test_transforms.py

* behaviour changes RandomErasing

* test fixes

* linter fix
parent 427633a7
...@@ -1345,20 +1345,33 @@ class Tester(unittest.TestCase): ...@@ -1345,20 +1345,33 @@ class Tester(unittest.TestCase):
def test_random_erasing(self): def test_random_erasing(self):
"""Unit tests for random erasing transform""" """Unit tests for random erasing transform"""
img = torch.rand([3, 224, 224]) img = torch.rand([3, 60, 60])
# Test Set 1: Erasing with int value # Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0)(img) img_re = transforms.RandomErasing(value=0.2)
assert img_re.size(0) == 3 i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
# Test Set 2: Erasing with random value assert img_output.size(0) == 3
# Test Set 2: Check if the unerased region is preserved
orig_unerased = img.clone()
orig_unerased[:, i:i + h, j:j + w] = 0
output_unerased = img_output.clone()
output_unerased[:, i:i + h, j:j + w] = 0
assert torch.equal(orig_unerased, output_unerased)
# Test Set 3: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img) img_re = transforms.RandomErasing(value='random')(img)
assert img_re.size(0) == 3 assert img_re.size(0) == 3
# Test Set 3: Erasing with tuple value # Test Set 4: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img) img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
assert img_re.size(0) == 3 assert img_re.size(0) == 3
# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
assert torch.equal(img_re, img)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1): ...@@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1):
return img return img
def erase(img, i, j, h, w, v): def erase(img, i, j, h, w, v, inplace=False):
""" Erase the input Tensor Image with given value. """ Erase the input Tensor Image with given value.
Args: Args:
...@@ -816,6 +816,7 @@ def erase(img, i, j, h, w, v): ...@@ -816,6 +816,7 @@ def erase(img, i, j, h, w, v):
h (int): Height of the erased region. h (int): Height of the erased region.
w (int): Width of the erased region. w (int): Width of the erased region.
v: Erasing value. v: Erasing value.
inplace(bool,optional): For in-place operations. By default is set False.
Returns: Returns:
Tensor Image: Erased image. Tensor Image: Erased image.
...@@ -823,5 +824,8 @@ def erase(img, i, j, h, w, v): ...@@ -823,5 +824,8 @@ def erase(img, i, j, h, w, v):
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
if not inplace:
img = img.clone()
img[:, i:i + h, j:j + w] = v img[:, i:i + h, j:j + w] = v
return img return img
...@@ -1196,6 +1196,8 @@ class RandomErasing(object): ...@@ -1196,6 +1196,8 @@ class RandomErasing(object):
erase all pixels. If a tuple of length 3, it is used to erase erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively. R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values. If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace.Default set to False.
Returns: Returns:
Erased Image. Erased Image.
# Examples: # Examples:
...@@ -1207,7 +1209,7 @@ class RandomErasing(object): ...@@ -1207,7 +1209,7 @@ class RandomErasing(object):
>>> ]) >>> ])
""" """
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0): def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list)) assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)") warnings.warn("range should be of kind (min, max)")
...@@ -1218,6 +1220,7 @@ class RandomErasing(object): ...@@ -1218,6 +1220,7 @@ class RandomErasing(object):
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.value = value self.value = value
self.inplace = inplace
@staticmethod @staticmethod
def get_params(img, scale, ratio, value=0): def get_params(img, scale, ratio, value=0):
...@@ -1261,5 +1264,5 @@ class RandomErasing(object): ...@@ -1261,5 +1264,5 @@ class RandomErasing(object):
""" """
if random.uniform(0, 1) < self.p: if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
return F.erase(img, x, y, h, w, v) return F.erase(img, x, y, h, w, v, self.inplace)
return img return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment