Commit 34833427 authored by Zhun Zhong's avatar Zhun Zhong Committed by Francisco Massa
Browse files

Fix bug to RandomErasing (#1095)

* Fix bug to Random Erasing

1. Avoid forever loop for getting parameters of erase.
2. replace' img_b' by 'img_c', because it indicates the channel.
3. replace v = torch.rand([img_c, h, w]) by v = torch.empty([img_c, h, w], dtype=torch.float32).normal_(). Normally distributed achieves better performance.

* add test

* Update test_transforms.py

* Update transforms.py

* Update test_transforms.py

* Update transforms.py

* Update functional.py
parent 957f5145
......@@ -1379,6 +1379,11 @@ class Tester(unittest.TestCase):
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
assert torch.equal(img_re, img)
# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
assert torch.equal(img_re, img)
if __name__ == '__main__':
unittest.main()
......@@ -828,7 +828,7 @@ def erase(img, i, j, h, w, v, inplace=False):
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool,optional): For in-place operations. By default is set False.
inplace(bool, optional): For in-place operations. By default is set False.
Returns:
Tensor Image: Erased image.
......
......@@ -1210,7 +1210,7 @@ class RandomErasing(object):
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace.Default set to False.
inplace: boolean to make this transform inplace. Default set to False.
Returns:
Erased Image.
......@@ -1223,7 +1223,7 @@ class RandomErasing(object):
>>> ])
"""
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
......@@ -1250,10 +1250,10 @@ class RandomErasing(object):
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
img_b, img_h, img_w = img.shape
img_c, img_h, img_w = img.shape
area = img_h * img_w
while True:
for attempt in range(10):
erase_area = random.uniform(scale[0], scale[1]) * area
aspect_ratio = random.uniform(ratio[0], ratio[1])
......@@ -1266,11 +1266,14 @@ class RandomErasing(object):
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.rand(img_b, h, w)
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v
# Return original image
return 0, 0, img_h, img_w, img
def __call__(self, img):
"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment