"...source/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "bfd88cad891c670b146bb678799fc2f531f75c39"
Unverified commit 7be2f55b authored by Federico Pozzi, committed by GitHub

port RandomHorizontalFlip to prototype API (#5563)



* refactor: port RandomHorizontalFlip to prototype API (#5523)

* refactor: merge HorizontalFlip and RandomHorizontalFlip

Add unit tests for RandomHorizontalFlip

* test: RandomHorizontalFlip with p=0

* refactor: remove type annotations from tests

* refactor: improve tests

* Update test/test_prototype_transforms.py
Co-authored-by: Federico Pozzi <federico.pozzi@argo.vision>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 60132302
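
With the old HorizontalFlip merged away, a single class now covers plain tensors, PIL images, and the prototype feature types, applying the flip with probability p. A minimal usage sketch against the prototype API as of this commit (variable names are illustrative; features.Image and features.BoundingBox are the wrappers exercised by the tests below):

import torch
from torchvision.prototype import transforms, features

flip = transforms.RandomHorizontalFlip(p=1.0)  # p=1.0 makes the outcome deterministic

image = features.Image(torch.rand(3, 10, 10))
box = features.BoundingBox(
    [0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)
)

flipped_image = flip(image)  # features.Image, mirrored along the width axis
flipped_box = flip(box)      # features.BoundingBox with x-coordinates mirrored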
test/test_prototype_transforms.py
@@ -2,9 +2,10 @@ import itertools
 
 import pytest
 import torch
+from common_utils import assert_equal
 from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 
 
 def make_vanilla_tensor_images(*args, **kwargs):
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms):
 class TestSmoke:
     @parametrize_from_transforms(
         transforms.RandomErasing(p=1.0),
-        transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
         transforms.ConvertImageDtype(),
+        transforms.RandomHorizontalFlip(),
     )
     def test_common(self, transform, input):
         transform(input)
@@ -188,3 +189,56 @@ class TestSmoke:
     )
     def test_convert_image_color_space(self, transform, input):
         transform(input)
+
+
+@pytest.mark.parametrize("p", [0.0, 1.0])
+class TestRandomHorizontalFlip:
+    def input_expected_image_tensor(self, p, dtype=torch.float32):
+        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
+        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
+
+        return input, expected if p == 1 else input
+
+    def test_simple_tensor(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        assert_equal(expected, actual)
+
+    def test_pil_image(self, p):
+        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(to_pil_image(input))
+
+        assert_equal(expected, pil_to_tensor(actual))
+
+    def test_features_image(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.Image(input))
+
+        assert_equal(features.Image(expected), actual)
+
+    def test_features_segmentation_mask(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.SegmentationMask(input))
+
+        assert_equal(features.SegmentationMask(expected), actual)
+
+    def test_features_bounding_box(self, p):
+        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
+        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        assert_equal(expected, actual)
+        assert actual.format == expected.format
+        assert actual.image_size == expected.image_size
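
The expectation in test_features_bounding_box follows from the XYXY flip rule: a horizontal flip on an image of width W maps x to W - x, so (x1, y1, x2, y2) becomes (W - x2, y1, W - x1, y2); with W = 10, the input (0, 0, 5, 5) lands on (5, 0, 10, 5), matching expected_image_tensor above. The same arithmetic as a standalone check (plain Python; flip_xyxy is a hypothetical helper, not part of this diff):

def flip_xyxy(box, image_width):
    # Mirror the x-coordinates and swap x1/x2 so the box stays well-formed (x1 <= x2).
    x1, y1, x2, y2 = box
    return (image_width - x2, y1, image_width - x1, y2)

assert flip_xyxy((0, 0, 5, 5), 10) == (5, 0, 10, 5)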
torchvision/prototype/transforms/__init__.py
@@ -8,13 +8,13 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import (
-    HorizontalFlip,
     Resize,
     CenterCrop,
     RandomResizedCrop,
     FiveCrop,
     TenCrop,
     BatchMultiCrop,
+    RandomHorizontalFlip,
     RandomZoomOut,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
torchvision/prototype/transforms/_geometry.py
@@ -13,11 +13,25 @@ from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_str
 
 from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
 
 
-class HorizontalFlip(Transform):
+class RandomHorizontalFlip(Transform):
+    def __init__(self, p: float = 0.5) -> None:
+        super().__init__()
+        self.p = p
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) >= self.p:
+            return sample
+
+        return super().forward(sample)
+
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         if isinstance(input, features.Image):
             output = F.horizontal_flip_image_tensor(input)
             return features.Image.new_like(input, output)
+        elif isinstance(input, features.SegmentationMask):
+            output = F.horizontal_flip_segmentation_mask(input)
+            return features.SegmentationMask.new_like(input, output)
         elif isinstance(input, features.BoundingBox):
             output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
             return features.BoundingBox.new_like(input, output)
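
Note where the randomness lives: the coin toss happens once in forward, before any per-type dispatch, so a sample containing an image together with its bounding box is either flipped as a whole or returned untouched as a whole. The two deterministic endpoints, sketched under the assumption that the prototype package is importable as below:

import torch
from torchvision.prototype import transforms, features

image = features.Image(torch.rand(3, 10, 10))

never = transforms.RandomHorizontalFlip(p=0.0)
always = transforms.RandomHorizontalFlip(p=1.0)

assert never(image) is image                       # torch.rand(1) >= 0.0 always holds, so the sample passes through
assert torch.equal(always(image), image.flip(-1))  # torch.rand(1) < 1.0 always holds, so the flip branch always runs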
torchvision/prototype/transforms/functional/__init__.py
@@ -40,6 +40,7 @@ from ._geometry import (
     horizontal_flip_bounding_box,
     horizontal_flip_image_tensor,
     horizontal_flip_image_pil,
+    horizontal_flip_segmentation_mask,
     resize_bounding_box,
     resize_image_tensor,
     resize_image_pil,
torchvision/prototype/transforms/functional/_geometry.py
@@ -15,6 +15,10 @@ horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
 
 
+def horizontal_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
+    return horizontal_flip_image_tensor(segmentation_mask)
+
+
 def horizontal_flip_bounding_box(
     bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
 ) -> torch.Tensor:
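
horizontal_flip_segmentation_mask can simply reuse the image kernel: a mask shares the image's (..., H, W) layout and mirroring moves pixels without changing their label values, so only coordinate-based kernels such as horizontal_flip_bounding_box need format-aware math. Two properties worth noting, sketched with plain torch (the 21-class shape is an arbitrary Pascal-VOC-style example, and the functional namespace is assumed importable as below):

import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 21, (1, 10, 10))
flipped = F.horizontal_flip_segmentation_mask(mask)

assert torch.equal(flipped, mask.flip(-1))                              # same as mirroring the width axis
assert torch.equal(F.horizontal_flip_segmentation_mask(flipped), mask)  # flipping twice is the identity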