Unverified commit 12adc542, authored by Vasilis Vryniotis, committed by GitHub

Add video support on MixUp and CutMix (#6733)

* Add video support on MixUp and CutMix

* Switch back to roll

* Fix tests and mypy

* Another mypy fix
parent a3fe870b
@@ -112,9 +112,12 @@ class TestSmoke:
         (
             transform,
             [
-                dict(image=image, one_hot_label=one_hot_label)
-                for image, one_hot_label in itertools.product(
-                    make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                dict(inpt=inpt, one_hot_label=one_hot_label)
+                for inpt, one_hot_label in itertools.product(
+                    itertools.chain(
+                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                        make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                    ),
                     make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                 )
             ],
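The smoke test now draws inputs from both make_images and make_videos. For context, a minimal sketch of the shape conventions the prototype transforms assume (batched images are 4D, batched videos carry an extra leading time dimension and are 5D); the concrete sizes below are hypothetical:

import torch

image_batch = torch.rand(4, 3, 16, 16)     # (B, C, H, W): batched image, ndim == 4
video_batch = torch.rand(4, 8, 3, 16, 16)  # (B, T, C, H, W): batched video, ndim == 5
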
@@ -107,8 +107,11 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

     def forward(self, *inputs: Any) -> Any:
-        if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
-            raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
+        if not (
+            has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(inputs, features.OneHotLabel)
+        ):
+            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
         if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
@@ -119,7 +122,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         if inpt.ndim < 2:
             raise ValueError("Need a batch of one hot labels")
         output = inpt.clone()
-        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
+        output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
         return features.OneHotLabel.wrap_like(inpt, output)
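This change is presumably the "Switch back to roll" item in the commit message: rolling along dim 0 pairs each sample with its predecessor in the batch regardless of how many trailing dimensions follow, whereas roll(1, -2) only hits the batch dimension when the labels are exactly 2D. A minimal sketch of the mixing arithmetic on a batch of three one-hot labels (out-of-place ops for clarity; lam is illustrative):

import torch

lam = 0.3
labels = torch.eye(3)  # one-hot labels for a batch of three samples
mixed = labels.roll(1, 0).mul(1.0 - lam).add(labels.mul(lam))
# Row i becomes lam * labels[i] + (1 - lam) * labels[i - 1]:
# tensor([[0.3000, 0.0000, 0.7000],
#         [0.7000, 0.3000, 0.0000],
#         [0.0000, 0.7000, 0.3000]])
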
@@ -129,14 +132,15 @@ class RandomMixup(_BaseMixupCutmix):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         lam = params["lam"]
-        if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
-            if inpt.ndim < 4:
-                raise ValueError("Need a batch of images")
+        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
+            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+            if inpt.ndim < expected_ndim:
+                raise ValueError("The transform expects a batched input")
             output = inpt.clone()
-            output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
-            if isinstance(inpt, features.Image):
-                output = features.Image.wrap_like(inpt, output)
+            output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
+            if isinstance(inpt, (features.Image, features.Video)):
+                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
         elif isinstance(inpt, features.OneHotLabel):
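Moving from roll(1, -4) to roll(1, 0) is what lets the same code path serve videos: dim -4 is the batch dimension only for a 4D image batch, while on a 5D (B, T, C, H, W) video batch it would have rolled along time. A sketch of the MixUp blend on a hypothetical video batch:

import torch

lam = 0.5
videos = torch.rand(2, 8, 3, 32, 32)  # hypothetical (B, T, C, H, W) batch
mixed = videos.roll(1, 0) * (1.0 - lam) + videos * lam
assert mixed.shape == videos.shape
# videos.roll(1, -4) would instead have shifted the 8 frames along T.
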
@@ -169,17 +173,18 @@ class RandomCutmix(_BaseMixupCutmix):
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
+        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
             box = params["box"]
-            if inpt.ndim < 4:
-                raise ValueError("Need a batch of images")
+            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+            if inpt.ndim < expected_ndim:
+                raise ValueError("The transform expects a batched input")
             x1, y1, x2, y2 = box
-            image_rolled = inpt.roll(1, -4)
+            rolled = inpt.roll(1, 0)
             output = inpt.clone()
-            output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-            if isinstance(inpt, features.Image):
-                output = features.Image.wrap_like(inpt, output)
+            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
+            if isinstance(inpt, (features.Image, features.Video)):
+                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
         elif isinstance(inpt, features.OneHotLabel):
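Because the box is pasted through [..., y1:y2, x1:x2], the ellipsis absorbs every leading dimension, so the identical indexing covers (B, C, H, W) images and (B, T, C, H, W) videos, cutting the same spatial window out of every frame of every clip. A small sketch under those shape assumptions (the box coordinates are hypothetical stand-ins for params["box"]):

import torch

videos = torch.arange(2 * 8 * 3 * 32 * 32, dtype=torch.float32).reshape(2, 8, 3, 32, 32)
rolled = videos.roll(1, 0)     # pair each clip with its batch neighbour
x1, y1, x2, y2 = 4, 4, 12, 12  # hypothetical box
output = videos.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]  # same window in all frames
assert torch.equal(output[0, :, :, y1:y2, x1:x2], videos[1, :, :, y1:y2, x1:x2])
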
@@ -483,8 +483,8 @@ class AugMix(_AutoAugmentBase):
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

         orig_dims = list(image_or_video.shape)
-        expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
-        batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
+        expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
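This hunk only renames expected_dim to expected_ndim for consistency with the new code above; the reshape it drives prepends singleton dimensions until the input is batched, and an already-batched input gets zero ones prepended and passes through unchanged. A minimal sketch of that view trick, with hypothetical sizes:

import torch

image = torch.rand(3, 32, 32)  # unbatched image, ndim == 3
expected_ndim = 4              # 5 would be used for a features.Video input
orig_dims = list(image.shape)
batch = image.view([1] * max(expected_ndim - image.ndim, 0) + orig_dims)
assert batch.shape == (1, 3, 32, 32)
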