"vscode:/vscode.git/clone" did not exist on "e61138be19cce30b1cf9e16dd8c35bbdbb86530d"
Unverified commit bbc1aac8, authored by Lezwon Castelino and committed by GitHub

Add SimpleCopyPaste augmentation (#5825)



* added simple POC

* added jitter and crop options

* added references

* moved simplecopypaste to detection module

* working POC for simple copy paste in detection

* added comments

* remove transforms from class
updated the labels
added gaussian blur

* removed loop for mask calculation

* replaced Gaussian blur with functional api

* added inplace operations

* added changes to accept tuples instead of tensors

* - make copy paste functional
- make only one copy of batch and target

* add inplace support within copy paste functional

* Updated code for copy-paste transform

* Fixed code formatting

* [skip ci] removed manual thresholding

* Replaced cropping by resizing data to paste

* Removed inplace arg (as useless) and put a check on iscrowd target

* code-formatting

* Updated copypaste op to make it torch scriptable
Added fallbacks to support LSJ

* Fixed flake8

* Updates according to the review
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 369317f4
references/detection/train.py

@@ -31,6 +31,8 @@ import utils
from coco_utils import get_coco, get_coco_kp
from engine import train_one_epoch, evaluate
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste


def get_dataset(name, image_set, transform, data_path):
@@ -145,6 +147,13 @@ def get_args_parser(add_help=True):
    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # Use CopyPaste augmentation training parameter
    parser.add_argument(
        "--use-copypaste",
        action="store_true",
        help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
    )

    return parser
@@ -180,8 +189,20 @@ def main(args):
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

    train_collate_fn = utils.collate_fn
    if args.use_copypaste:
        if args.data_augmentation != "lsj":
            raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

        copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)

        def copypaste_collate_fn(batch):
            return copypaste(*utils.collate_fn(batch))

        train_collate_fn = copypaste_collate_fn

    data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
+        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
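Because each sample borrows objects from another sample in the same batch, the augmentation cannot live in the per-image dataset transforms; it is wired in as the DataLoader's collate_fn instead. Below is a minimal sketch of that wiring with a toy in-memory dataset. The dataset, its image sizes, and its object layout are made up for illustration; `utils.collate_fn` and `SimpleCopyPaste` are the ones from references/detection, so the snippet assumes it is run from that directory.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import utils                              # references/detection/utils.py
from transforms import SimpleCopyPaste    # references/detection/transforms.py


class ToyDetectionDataset(Dataset):
    """Yields (image, target) pairs with the keys SimpleCopyPaste expects."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        image = torch.rand(3, 128, 128)
        # Two square objects whose position depends on the sample index.
        off = 8 * (idx % 4)
        masks = torch.zeros(2, 128, 128, dtype=torch.uint8)
        masks[0, 10 + off : 40 + off, 10:40] = 1
        masks[1, 60:90, 50 + off : 100 + off] = 1
        target = {
            "masks": masks,
            "boxes": torch.tensor(
                [[10.0, 10.0 + off, 39.0, 39.0 + off], [50.0 + off, 60.0, 99.0 + off, 89.0]]
            ),
            "labels": torch.tensor([1, 2]),
        }
        return image, target


copypaste = SimpleCopyPaste(blending=True)


def copypaste_collate_fn(batch):
    # utils.collate_fn zips the batch into (images, targets) tuples;
    # SimpleCopyPaste then pastes objects between neighbouring samples.
    return copypaste(*utils.collate_fn(batch))


loader = DataLoader(ToyDetectionDataset(), batch_size=2, collate_fn=copypaste_collate_fn)
images, targets = next(iter(loader))  # this batch is already copy-paste augmented
```

In the actual script this path is only taken when --use-copypaste is passed together with --data-augmentation lsj, as shown in the hunk above.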
references/detection/transforms.py

@@ -3,6 +3,7 @@ from typing import List, Tuple, Dict, Optional, Union
import torch
import torchvision
from torch import nn, Tensor
from torchvision import ops
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T, InterpolationMode
@@ -437,3 +438,157 @@ class RandomShortestSize(nn.Module):
                )

        return image, target

def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:

    # Random paste targets selection:
    num_masks = len(paste_target["masks"])

    if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target

    # We have to please torch script by explicitly specifying dtype as torch.long
    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    random_selection = torch.unique(random_selection).to(torch.long)

    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]

    masks = target["masks"]

    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
        # resize bboxes:
        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)

    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )

    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

    # Copy-paste masks:
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]

    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}

    out_target["masks"] = torch.cat([masks, paste_masks])

    # Copy-paste boxes and labels
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])

    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])

    # Update additional optional keys: area and iscrowd, if they exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    if "iscrowd" in target and "iscrowd" in paste_target:
        # target["iscrowd"] size can differ from mask size (non_all_zero_masks):
        # for example, if a previous transform geometrically modified masks/boxes/labels
        # but did not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

    # Check for degenerate boxes and remove them
    boxes = out_target["boxes"]
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)

        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

    return image, out_target

class SimpleCopyPaste(torch.nn.Module):
    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Cannot check for instance type dict inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of input images
        # paste_images = [t2, t3, ..., tN, t1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images: List[torch.Tensor] = []
        output_targets: List[Dict[str, Tensor]] = []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
            output_image, output_data = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            output_images.append(output_image)
            output_targets.append(output_data)

        return output_images, output_targets

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
        return s
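For a quick sanity check of the transform's contract, it can also be called directly on synthetic samples. The sketch below is illustrative only: the tensor shapes and dict keys follow what forward asserts, while the image size and object layout are made up, and it again assumes the transforms.py from this diff is importable.

```python
import torch
from torchvision import ops

from transforms import SimpleCopyPaste  # references/detection/transforms.py

# One object in the source image, one object in the paste image that
# partially covers it.
h = w = 64
src_mask = torch.zeros(1, h, w, dtype=torch.uint8)
src_mask[0, 8:48, 8:48] = 1
paste_mask = torch.zeros(1, h, w, dtype=torch.uint8)
paste_mask[0, 8:48, 32:56] = 1

images = [torch.rand(3, h, w), torch.rand(3, h, w)]
targets = [
    {"masks": src_mask, "boxes": ops.masks_to_boxes(src_mask), "labels": torch.tensor([1])},
    {"masks": paste_mask, "boxes": ops.masks_to_boxes(paste_mask), "labels": torch.tensor([2])},
]

out_images, out_targets = SimpleCopyPaste(blending=False)(images, targets)

# The first output now carries both objects, and the box of the original
# object is recomputed from its occluded mask via ops.masks_to_boxes.
print(out_targets[0]["labels"])  # tensor([1, 2])
print(targets[0]["boxes"])       # tensor([[ 8.,  8., 47., 47.]])
print(out_targets[0]["boxes"])   # x_max of the first box shrinks from 47 to 31, plus the pasted box
```

Blending is disabled here only so that the recomputed box coordinates come out exact; the training recipe above keeps the default Gaussian-blur blending.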