"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e3efbc2d9094685dd2d4ae143853941f82f167af"
Unverified Commit 7bb5e41b authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding Scale Jitter transform for detection (#5435)

* Adding Scale Jitter in references.

* Update documentation.

* Address review comments.
parent 48a61df2
...@@ -4,7 +4,7 @@ import torch
import torchvision
from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T, InterpolationMode
def _flip_coco_person_keypoints(kps, width):
...@@ -282,3 +282,52 @@ class RandomPhotometricDistort(nn.Module):
        image = F.to_pil_image(image)
        return image, target
class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes within the specified scale range.

    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    Args:
        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
        scale_range (tuple of floats): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
            range a <= scale <= b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize ``image`` to a randomly scaled multiple of ``target_size`` and rescale the target.

        Args:
            image (Tensor or PIL Image): the image to transform; 2-D tensors are
                promoted to 3-D by adding a leading channel dimension.
            target (dict, optional): detection target; "boxes" and "masks" entries,
                when present, are updated in place to match the resized image.

        Returns:
            Tuple of the resized image and the (possibly updated) target.
        """
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        # The new size is derived from target_size, NOT from the input image's own
        # size, so the original spatial size is needed to rescale the boxes.
        orig_width, orig_height = F.get_image_size(image)

        # Sample a scale factor uniformly from [scale_range[0], scale_range[1]).
        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        new_width = int(self.target_size[1] * r)
        new_height = int(self.target_size[0] * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Bug fix: boxes must be scaled by the actual per-axis resize ratios
            # (new size / original size), not by the sampled factor `r`, which is
            # only correct when the input already has size `target_size`.
            # Assumes boxes are in [x1, y1, x2, y2] format (x coords at even
            # columns) — consistent with the torchvision detection references.
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment