Unverified commit c6a715c7 authored by vfdev, committed by GitHub

[proto] Ported SimpleCopyPaste transform (#6451)



* WIP

* [proto] Added SimpleCopyPaste transform

* Refactored and cleaned the implementation and added tests

* Fixing code

* Fixed code formatting issue

* Minor updates

* Fixed merge issue
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent d8025b9a
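
For context before the diff: a minimal usage sketch of the ported transform on a two-sample detection batch. It mirrors the constructors used in the tests below; the features.Image(...) call, the make_sample helper, and the exact batch layout are illustrative assumptions, not part of this commit.

import torch
from torchvision.prototype import features, transforms

def make_sample(fill, box, label):
    # Build one fake detection sample: a constant image plus a single box/mask/label.
    h = w = 32
    image = features.Image(fill * torch.ones(3, h, w))
    masks = torch.zeros(1, h, w)
    masks[0, box[1] : box[3], box[0] : box[2]] = 1
    target = {
        "boxes": features.BoundingBox(
            torch.tensor([box], dtype=torch.float32), format="XYXY", image_size=(h, w)
        ),
        "masks": features.SegmentationMask(masks),
        "labels": features.Label(torch.tensor([label])),
    }
    return image, target

# Two samples; SimpleCopyPaste pairs each sample with a rolled copy of the batch
# and appends the pasted boxes/masks/labels to each target.
batch = [make_sample(0.2, [2, 3, 8, 9], 1), make_sample(0.8, [12, 13, 19, 18], 2)]
out = transforms.SimpleCopyPaste(p=1.0)(batch)
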
@@ -1313,6 +1313,97 @@ class TestRandomShortestSize:
        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)

class TestSimpleCopyPaste:
    def create_fake_image(self, mocker, image_type):
        if image_type == PIL.Image.Image:
            return PIL.Image.new("RGB", (32, 32), 123)
        return mocker.MagicMock(spec=image_type)

    def test__extract_image_targets_assertion(self, mocker):
        transform = transforms.SimpleCopyPaste()

        # deliberately malformed sample: one image and one label, but two sets of boxes and masks,
        # so _extract_image_targets must raise
        flat_sample = [
            # images, batch size = 2
            self.create_fake_image(mocker, features.Image),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
        ]

        with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
            transform._extract_image_targets(flat_sample)
    @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
    def test__extract_image_targets(self, image_type, mocker):
        transform = transforms.SimpleCopyPaste()

        flat_sample = [
            # images, batch size = 2
            self.create_fake_image(mocker, image_type),
            self.create_fake_image(mocker, image_type),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
        ]

        images, targets = transform._extract_image_targets(flat_sample)

        assert len(images) == len(targets) == 2
        if image_type == PIL.Image.Image:
            torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
            torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
        else:
            assert images[0] == flat_sample[0]
            assert images[1] == flat_sample[1]
    def test__copy_paste(self):
        image = 2 * torch.ones(3, 32, 32)
        masks = torch.zeros(2, 32, 32)
        masks[0, 3:9, 2:8] = 1
        masks[1, 20:30, 20:30] = 1
        target = {
            "boxes": features.BoundingBox(
                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
            ),
            "masks": features.SegmentationMask(masks),
            "labels": features.Label(torch.tensor([1, 2])),
        }

        paste_image = 10 * torch.ones(3, 32, 32)
        paste_masks = torch.zeros(2, 32, 32)
        paste_masks[0, 13:19, 12:18] = 1
        paste_masks[1, 15:19, 1:8] = 1
        paste_target = {
            "boxes": features.BoundingBox(
                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
            ),
            "masks": features.SegmentationMask(paste_masks),
            "labels": features.Label(torch.tensor([3, 4])),
        }

        transform = transforms.SimpleCopyPaste()
        random_selection = torch.tensor([0, 1])
        output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)

        assert output_image.unique().tolist() == [2, 10]
        assert output_target["boxes"].shape == (4, 4)
        torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])

        torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))

        assert output_target["masks"].shape == (4, 32, 32)
        torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
        torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])

class TestFixedSizeCrop:
    def test__get_params(self, mocker):
        crop_size = (7, 7)
...
@@ -2,7 +2,7 @@ from . import functional  # usort: skip

from ._transform import Transform  # usort: skip

from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._color import (
    ColorJitter,
...
import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor

from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw

@@ -178,3 +183,187 @@ class RandomCutmix(_BaseMixupCutmix):
            return self._mixup_onehotlabel(inpt, lam_adjusted)
        else:
            return inpt

class SimpleCopyPaste(_RandomApplyTransform):
    def __init__(
        self,
        p: float = 0.5,
        blending: bool = True,
        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__(p=p)
        self.resize_interpolation = resize_interpolation
        self.blending = blending
    def _copy_paste(
        self,
        image: Any,
        target: Dict[str, Any],
        paste_image: Any,
        paste_target: Dict[str, Any],
        random_selection: torch.Tensor,
        blending: bool = True,
        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> Tuple[Any, Dict[str, Any]]:

        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])

        masks = target["masks"]

        # Resize the paste data if its spatial size differs from the source image.
        # This differs from the TF implementation, which assumes equal-sized inputs
        # (for example, coming from LSJ data augmentation).
        size1 = image.shape[-2:]
        size2 = paste_image.shape[-2:]
        if size1 != size2:
            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
            paste_masks = F.resize(paste_masks, size=size1)
            paste_boxes = F.resize(paste_boxes, size=size1)

        paste_alpha_mask = paste_masks.sum(dim=0) > 0

        if blending:
            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

        # Copy-paste images:
        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

        # Copy-paste masks:
        masks = masks * (~paste_alpha_mask)
        non_all_zero_masks = masks.sum((-1, -2)) > 0
        masks = masks[non_all_zero_masks]

        # Do a shallow copy of the target dict
        out_target = {k: v for k, v in target.items()}

        out_target["masks"] = torch.cat([masks, paste_masks])

        # Copy-paste boxes and labels
        bbox_format = target["boxes"].format
        xyxy_boxes = masks_to_boxes(masks)
        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
        # so we add +1 to x2y2 here. This needs further investigation.
        xyxy_boxes[:, 2:] += 1
        boxes = F.convert_bounding_box_format(
            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
        )
        out_target["boxes"] = torch.cat([boxes, paste_boxes])

        labels = target["labels"][non_all_zero_masks]
        out_target["labels"] = torch.cat([labels, paste_labels])

        # Check for degenerate boxes and remove them
        boxes = F.convert_bounding_box_format(
            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
        )
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
        if degenerate_boxes.any():
            valid_targets = ~degenerate_boxes.any(dim=1)

            out_target["boxes"] = boxes[valid_targets]
            out_target["masks"] = out_target["masks"][valid_targets]
            out_target["labels"] = out_target["labels"][valid_targets]

        return image, out_target
    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
        # fetch all images, bboxes, masks and labels from the unstructured input
        # as List[image], List[BoundingBox], List[SegmentationMask], List[Label]
        images, bboxes, masks, labels = [], [], [], []
        for obj in flat_sample:
            if isinstance(obj, features.Image) or is_simple_tensor(obj):
                images.append(obj)
            elif isinstance(obj, PIL.Image.Image):
                images.append(pil_to_tensor(obj))
            elif isinstance(obj, features.BoundingBox):
                bboxes.append(obj)
            elif isinstance(obj, features.SegmentationMask):
                masks.append(obj)
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                labels.append(obj)

        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
            )

        targets = []
        for bbox, mask, label in zip(bboxes, masks, labels):
            targets.append({"boxes": bbox, "masks": mask, "labels": label})

        return images, targets
    def _insert_outputs(
        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
    ) -> None:
        # c0, c1, c2, c3 count how many images, boxes, masks and labels have been written back so far
        c0, c1, c2, c3 = 0, 0, 0, 0
        for i, obj in enumerate(flat_sample):
            if isinstance(obj, features.Image):
                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
                c0 += 1
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_image_pil(output_images[c0])
                c0 += 1
            elif is_simple_tensor(obj):
                flat_sample[i] = output_images[c0]
                c0 += 1
            elif isinstance(obj, features.BoundingBox):
                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
                c1 += 1
            elif isinstance(obj, features.SegmentationMask):
                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
                c2 += 1
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                c3 += 1
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        flat_sample, spec = tree_flatten(sample)

        images, targets = self._extract_image_targets(flat_sample)

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of input images:
        # paste_images = [t2, t3, ..., tN, t1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images, output_targets = [], []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):

            # Random paste targets selection:
            num_masks = len(paste_target["masks"])

            if num_masks < 1:
                # The degenerate case with num_masks=0 can happen with LSJ;
                # just return (image, target) unchanged
                output_image, output_target = image, target
            else:
                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
                random_selection = torch.unique(random_selection)

                output_image, output_target = self._copy_paste(
                    image,
                    target,
                    paste_image,
                    paste_target,
                    random_selection=random_selection,
                    blending=self.blending,
                    resize_interpolation=self.resize_interpolation,
                )
            output_images.append(output_image)
            output_targets.append(output_target)

        # Insert updated images and targets into the input flat_sample
        self._insert_outputs(flat_sample, output_images, output_targets)

        return tree_unflatten(flat_sample, spec)
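
To make the core compositing step in _copy_paste concrete, here is a small standalone illustration with toy tensors (not part of the diff): pixels under the union of the selected paste masks come from the paste image, everything else from the source image.

import torch

image = 2 * torch.ones(1, 4, 4)          # source image filled with 2s
paste_image = 10 * torch.ones(1, 4, 4)   # paste image filled with 10s

paste_masks = torch.zeros(1, 4, 4)       # one paste mask over the top-left 2x2 corner
paste_masks[0, :2, :2] = 1

paste_alpha_mask = paste_masks.sum(dim=0) > 0  # union of paste masks (blending omitted here)
out = image * (~paste_alpha_mask) + paste_image * paste_alpha_mask
# out[0] is 10 in the pasted 2x2 corner and 2 everywhere else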