Unverified Commit 6279089a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix MixUp and CutMix (#6464)

* fix MixUp and CutMix

* improve error message
parent acf30e98
...@@ -99,10 +99,8 @@ class _BaseMixupCutmix(_RandomApplyTransform): ...@@ -99,10 +99,8 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def forward(self, *inpts: Any) -> Any: def forward(self, *inpts: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0] sample = inpts if len(inpts) > 1 else inpts[0]
if not ( if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel) raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError( raise TypeError(
f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels." f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
...@@ -123,12 +121,16 @@ class RandomMixup(_BaseMixupCutmix): ...@@ -123,12 +121,16 @@ class RandomMixup(_BaseMixupCutmix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"] lam = params["lam"]
if isinstance(inpt, features.Image): if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
if inpt.ndim < 4: if inpt.ndim < 4:
raise ValueError("Need a batch of images") raise ValueError("Need a batch of images")
output = inpt.clone() output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return features.Image.new_like(inpt, output)
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
elif isinstance(inpt, features.OneHotLabel): elif isinstance(inpt, features.OneHotLabel):
return self._mixup_onehotlabel(inpt, lam) return self._mixup_onehotlabel(inpt, lam)
else: else:
...@@ -159,7 +161,7 @@ class RandomCutmix(_BaseMixupCutmix): ...@@ -159,7 +161,7 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted) return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image): if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
box = params["box"] box = params["box"]
if inpt.ndim < 4: if inpt.ndim < 4:
raise ValueError("Need a batch of images") raise ValueError("Need a batch of images")
...@@ -167,7 +169,11 @@ class RandomCutmix(_BaseMixupCutmix): ...@@ -167,7 +169,11 @@ class RandomCutmix(_BaseMixupCutmix):
image_rolled = inpt.roll(1, -4) image_rolled = inpt.roll(1, -4)
output = inpt.clone() output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return features.Image.new_like(inpt, output)
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
elif isinstance(inpt, features.OneHotLabel): elif isinstance(inpt, features.OneHotLabel):
lam_adjusted = params["lam_adjusted"] lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted) return self._mixup_onehotlabel(inpt, lam_adjusted)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment