Unverified commit c84dbfad, authored by Vasilis Vryniotis, committed by GitHub
Browse files

[prototype] Speed up Augment Transform Classes (#6835)

* Move the normalization of the `value` argument of `RandomErasing` from runtime (`_get_params`) to the constructor.

* Speed up mixing in MixUp/CutMix and apply a small optimization to SimpleCopyPaste.

* Apply nits.
parent 8e0e7157
...@@ -40,7 +40,14 @@ class RandomErasing(_RandomApplyTransform): ...@@ -40,7 +40,14 @@ class RandomErasing(_RandomApplyTransform):
raise ValueError("Scale should be between 0 and 1") raise ValueError("Scale should be between 0 and 1")
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.value = value if isinstance(value, (int, float)):
self.value = [value]
elif isinstance(value, str):
self.value = None
elif isinstance(value, tuple):
self.value = list(value)
else:
self.value = value
self.inplace = inplace self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio)) self._log_ratio = torch.log(torch.tensor(self.ratio))
...@@ -48,16 +55,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -48,16 +55,7 @@ class RandomErasing(_RandomApplyTransform):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs) img_c, img_h, img_w = query_chw(flat_inputs)
if isinstance(self.value, (int, float)): if self.value is not None and not (len(self.value) in (1, img_c)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img_c)):
raise ValueError( raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
) )
...@@ -79,10 +77,10 @@ class RandomErasing(_RandomApplyTransform): ...@@ -79,10 +77,10 @@ class RandomErasing(_RandomApplyTransform):
if not (h < img_h and w < img_w): if not (h < img_h and w < img_w):
continue continue
if value is None: if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else: else:
v = torch.tensor(value)[:, None, None] v = torch.tensor(self.value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item() i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item() j = torch.randint(0, img_w - w + 1, size=(1,)).item()
...@@ -121,8 +119,7 @@ class _BaseMixupCutmix(_RandomApplyTransform): ...@@ -121,8 +119,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
if inpt.ndim < 2: if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels") raise ValueError("Need a batch of one hot labels")
output = inpt.clone() output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
return features.OneHotLabel.wrap_like(inpt, output) return features.OneHotLabel.wrap_like(inpt, output)
...@@ -136,8 +133,7 @@ class RandomMixup(_BaseMixupCutmix): ...@@ -136,8 +133,7 @@ class RandomMixup(_BaseMixupCutmix):
expected_ndim = 5 if isinstance(inpt, features.Video) else 4 expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim: if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input") raise ValueError("The transform expects a batched input")
output = inpt.clone() output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
if isinstance(inpt, (features.Image, features.Video)): if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
...@@ -243,11 +239,12 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -243,11 +239,12 @@ class SimpleCopyPaste(_RandomApplyTransform):
if blending: if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images: # Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
# Copy-paste masks: # Copy-paste masks:
masks = masks * (~paste_alpha_mask) masks = masks * inverse_paste_alpha_mask
non_all_zero_masks = masks.sum((-1, -2)) > 0 non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks] masks = masks[non_all_zero_masks]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment