Unverified Commit f3c89cc6 authored by Nicolas Hug, committed by GitHub

Remove cutmix and mixup from prototype (#7787)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent cab9fba8
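The removed RandomCutMix and RandomMixUp presumably have stable counterparts in torchvision.transforms.v2 (CutMix and MixUp). Below is a minimal usage sketch of that replacement, assuming a torchvision build that ships the v2 transforms; the batch shapes and NUM_CLASSES value are illustrative only.

import torch
from torchvision.transforms import v2

NUM_CLASSES = 10  # illustrative number of classes

# Randomly pick CutMix or MixUp for each batch, a common pattern in training loops.
cutmix_or_mixup = v2.RandomChoice([
    v2.CutMix(alpha=1.0, num_classes=NUM_CLASSES),
    v2.MixUp(alpha=1.0, num_classes=NUM_CLASSES),
])

images = torch.rand(4, 3, 224, 224)             # batched float images (B, C, H, W)
labels = torch.randint(0, NUM_CLASSES, (4,))    # integer class labels (B,)
images, labels = cutmix_or_mixup(images, labels)  # labels become soft targets (B, NUM_CLASSES)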
@@ -12,13 +12,10 @@ from common_utils import (
     make_bounding_box,
     make_detection_mask,
     make_image,
-    make_images,
-    make_segmentation_mask,
     make_video,
-    make_videos,
 )
-from prototype_common_utils import make_label, make_one_hot_labels
+from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
@@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs):
     )
-@parametrize(
-    [
-        (
-            transform,
-            [
-                dict(inpt=inpt, one_hot_label=one_hot_label)
-                for inpt, one_hot_label in itertools.product(
-                    itertools.chain(
-                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
-                        make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
-                    ),
-                    make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
-                )
-            ],
-        )
-        for transform in [
-            transforms.RandomMixUp(alpha=1.0),
-            transforms.RandomCutMix(alpha=1.0),
-        ]
-    ]
-)
-def test_mixup_cutmix(transform, input):
-    transform(input)
-
-    input_copy = dict(input)
-    input_copy["path"] = "/path/to/somewhere"
-    input_copy["num"] = 1234
-    transform(input_copy)
-
-    # Check if we raise an error if sample contains bbox or mask or label
-    err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
-    input_copy = dict(input)
-    for unsup_data in [
-        make_label(),
-        make_bounding_box(format="XYXY"),
-        make_detection_mask(),
-        make_segmentation_mask(),
-    ]:
-        input_copy["unsupported"] = unsup_data
-        with pytest.raises(TypeError, match=err_msg):
-            transform(input_copy)
-
-
 class TestSimpleCopyPaste:
     def create_fake_image(self, mocker, image_type):
         if image_type == PIL.Image.Image:
...
 from ._presets import StereoMatching  # usort: skip
-from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
+from ._augment import SimpleCopyPaste
 from ._geometry import FixedSizeCrop
 from ._misc import PermuteDimensions, TransposeDimensions
 from ._type_conversion import LabelToOneHot
-import math
 from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import PIL.Image
@@ -9,100 +8,8 @@ from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints as proto_datapoints

 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
-from torchvision.transforms.v2._transform import _RandomApplyTransform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
+from torchvision.transforms.v2.utils import is_simple_tensor
-
-
-class _BaseMixUpCutMix(_RandomApplyTransform):
-    def __init__(self, alpha: float, p: float = 0.5) -> None:
-        super().__init__(p=p)
-        self.alpha = alpha
-        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
-
-    def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if not (
-            has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
-            and has_any(flat_inputs, proto_datapoints.OneHotLabel)
-        ):
-            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
-            raise TypeError(
-                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
-            )
-
-    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
-        if inpt.ndim < 2:
-            raise ValueError("Need a batch of one hot labels")
-        output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-        return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
-
-
-class RandomMixUp(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        lam = params["lam"]
-
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            return self._mixup_onehotlabel(inpt, lam)
-        else:
-            return inpt
-
-
-class RandomCutMix(_BaseMixUpCutMix):
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        lam = float(self._dist.sample(()))  # type: ignore[arg-type]
-
-        H, W = query_size(flat_inputs)
-
-        r_x = torch.randint(W, ())
-        r_y = torch.randint(H, ())
-
-        r = 0.5 * math.sqrt(1.0 - lam)
-        r_w_half = int(r * W)
-        r_h_half = int(r * H)
-
-        x1 = int(torch.clamp(r_x - r_w_half, min=0))
-        y1 = int(torch.clamp(r_y - r_h_half, min=0))
-        x2 = int(torch.clamp(r_x + r_w_half, max=W))
-        y2 = int(torch.clamp(r_y + r_h_half, max=H))
-        box = (x1, y1, x2, y2)
-
-        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
-
-        return dict(box=box, lam_adjusted=lam_adjusted)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
-            box = params["box"]
-
-            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
-            if inpt.ndim < expected_ndim:
-                raise ValueError("The transform expects a batched input")
-
-            x1, y1, x2, y2 = box
-            rolled = inpt.roll(1, 0)
-            output = inpt.clone()
-            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-
-            return output
-        elif isinstance(inpt, proto_datapoints.OneHotLabel):
-            lam_adjusted = params["lam_adjusted"]
-            return self._mixup_onehotlabel(inpt, lam_adjusted)
-        else:
-            return inpt


 class SimpleCopyPaste(Transform):
...
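For reference, here is a condensed, self-contained sketch of what the deleted RandomMixUp/RandomCutMix code above computes on a batch of images with one-hot labels. The function names mixup_batch and cutmix_batch are illustrative and not part of torchvision.

import math

import torch


def mixup_batch(images: torch.Tensor, one_hot: torch.Tensor, alpha: float = 1.0):
    # images: (B, C, H, W) float tensor, one_hot: (B, num_classes) float tensor.
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    # Blend each sample with the next one in the batch (roll by 1 along dim 0).
    mixed_images = lam * images + (1.0 - lam) * images.roll(1, 0)
    mixed_labels = lam * one_hot + (1.0 - lam) * one_hot.roll(1, 0)
    return mixed_images, mixed_labels


def cutmix_batch(images: torch.Tensor, one_hot: torch.Tensor, alpha: float = 1.0):
    H, W = images.shape[-2:]
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    # Sample a box centered at (r_x, r_y) whose area is roughly (1 - lam) of the image.
    r_x, r_y = int(torch.randint(W, ())), int(torch.randint(H, ()))
    r = 0.5 * math.sqrt(1.0 - lam)
    r_w_half, r_h_half = int(r * W), int(r * H)
    x1, x2 = max(r_x - r_w_half, 0), min(r_x + r_w_half, W)
    y1, y2 = max(r_y - r_h_half, 0), min(r_y + r_h_half, H)
    # Paste the box region from the rolled batch, then mix labels by the actual box area.
    output = images.clone()
    output[..., y1:y2, x1:x2] = images.roll(1, 0)[..., y1:y2, x1:x2]
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return output, lam_adjusted * one_hot + (1.0 - lam_adjusted) * one_hot.roll(1, 0)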