Commit 793c4e82 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Minor optimization to RandomErasing (#1087)

* Minor optimization to RandomErasing

This PR adds an additional check on `p` argument and prevents computing `img.shape` multiple times.

* linting
parent 9dfca9af
...@@ -1229,6 +1229,8 @@ class RandomErasing(object): ...@@ -1229,6 +1229,8 @@ class RandomErasing(object):
warnings.warn("range should be of kind (min, max)") warnings.warn("range should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1: if scale[0] < 0 or scale[1] > 1:
raise ValueError("range of scale should be between 0 and 1") raise ValueError("range of scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("range of random erasing probability should be between 0 and 1")
self.p = p self.p = p
self.scale = scale self.scale = scale
...@@ -1248,7 +1250,8 @@ class RandomErasing(object): ...@@ -1248,7 +1250,8 @@ class RandomErasing(object):
Returns: Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
""" """
area = img.shape[1] * img.shape[2] img_b, img_h, img_w = img.shape
area = img_h * img_w
while True: while True:
erase_area = random.uniform(scale[0], scale[1]) * area erase_area = random.uniform(scale[0], scale[1]) * area
...@@ -1257,13 +1260,13 @@ class RandomErasing(object): ...@@ -1257,13 +1260,13 @@ class RandomErasing(object):
h = int(round(math.sqrt(erase_area * aspect_ratio))) h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio)))
if h < img.shape[1] and w < img.shape[2]: if h < img_h and w < img_w:
i = random.randint(0, img.shape[1] - h) i = random.randint(0, img_h - h)
j = random.randint(0, img.shape[2] - w) j = random.randint(0, img_w - w)
if isinstance(value, numbers.Number): if isinstance(value, numbers.Number):
v = value v = value
elif isinstance(value, torch._six.string_classes): elif isinstance(value, torch._six.string_classes):
v = torch.rand(img.size()[0], h, w) v = torch.rand(img_b, h, w)
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w) v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v return i, j, h, w, v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment