Commit 36158026 authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Making RandomErasing Out of place by default and improvising tests for the same (#1060)

* test improved

* Update test_transforms.py

* behaviour changes RandomErasing

* test fixes

* linter fix
parent 427633a7
......@@ -1345,20 +1345,33 @@ class Tester(unittest.TestCase):
def test_random_erasing(self):
"""Unit tests for random erasing transform"""
img = torch.rand([3, 224, 224])
img = torch.rand([3, 60, 60])
# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0)(img)
assert img_re.size(0) == 3
# Test Set 2: Erasing with random value
img_re = transforms.RandomErasing(value=0.2)
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
assert img_output.size(0) == 3
# Test Set 2: Check if the unerased region is preserved
orig_unerased = img.clone()
orig_unerased[:, i:i + h, j:j + w] = 0
output_unerased = img_output.clone()
output_unerased[:, i:i + h, j:j + w] = 0
assert torch.equal(orig_unerased, output_unerased)
# Test Set 3: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img)
assert img_re.size(0) == 3
# Test Set 3: Erasing with tuple value
# Test Set 4: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
assert img_re.size(0) == 3
# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
assert torch.equal(img_re, img)
if __name__ == '__main__':
unittest.main()
......@@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1):
return img
def erase(img, i, j, h, w, v):
def erase(img, i, j, h, w, v, inplace=False):
""" Erase the input Tensor Image with given value.
Args:
......@@ -816,6 +816,7 @@ def erase(img, i, j, h, w, v):
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool,optional): For in-place operations. By default is set False.
Returns:
Tensor Image: Erased image.
......@@ -823,5 +824,8 @@ def erase(img, i, j, h, w, v):
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
if not inplace:
img = img.clone()
img[:, i:i + h, j:j + w] = v
return img
......@@ -1196,6 +1196,8 @@ class RandomErasing(object):
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace.Default set to False.
Returns:
Erased Image.
# Examples:
......@@ -1207,7 +1209,7 @@ class RandomErasing(object):
>>> ])
"""
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
......@@ -1218,6 +1220,7 @@ class RandomErasing(object):
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace
@staticmethod
def get_params(img, scale, ratio, value=0):
......@@ -1261,5 +1264,5 @@ class RandomErasing(object):
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
return F.erase(img, x, y, h, w, v)
return F.erase(img, x, y, h, w, v, self.inplace)
return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment