Unverified Commit 12adc542 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add video support on MixUp and CutMix (#6733)

* Add video support on MixUp and CutMix

* Switch back to roll

* Fix tests and mypy

* Another mypy fix
parent a3fe870b
......@@ -112,9 +112,12 @@ class TestSmoke:
(
transform,
[
dict(image=image, one_hot_label=one_hot_label)
for image, one_hot_label in itertools.product(
make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
dict(inpt=inpt, one_hot_label=one_hot_label)
for inpt, one_hot_label in itertools.product(
itertools.chain(
make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
),
make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
)
],
......
......@@ -107,8 +107,11 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if not (
has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
and has_any(inputs, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
......@@ -119,7 +122,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
return features.OneHotLabel.wrap_like(inpt, output)
......@@ -129,14 +132,15 @@ class RandomMixup(_BaseMixupCutmix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output)
if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, features.OneHotLabel):
......@@ -169,17 +173,18 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
box = params["box"]
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
x1, y1, x2, y2 = box
image_rolled = inpt.roll(1, -4)
rolled = inpt.roll(1, 0)
output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, features.OneHotLabel):
......
......@@ -483,8 +483,8 @@ class AugMix(_AutoAugmentBase):
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment