"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "35a969d297cba69110d175ee79c59312b9f49e1e"
Unverified Commit 585b3b13 authored by Federico Pozzi, committed by GitHub

refactor: port RandomVerticalFlip to prototype API (#5524) (#5633)


Co-authored-by: Federico Pozzi <federico.pozzi@argo.vision>
parent eb6e3915
@@ -243,3 +243,56 @@ class TestRandomHorizontalFlip:
        assert_equal(expected, actual)
        assert actual.format == expected.format
        assert actual.image_size == expected.image_size


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomVerticalFlip:
    def input_expected_image_tensor(self, p, dtype=torch.float32):
        input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
        expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)

        # With p == 0 the transform is a no-op, so the expected output is the input itself.
        return input, (expected if p == 1 else input)

    def test_simple_tensor(self, p):
        input, expected = self.input_expected_image_tensor(p)
        transform = transforms.RandomVerticalFlip(p=p)

        actual = transform(input)

        assert_equal(expected, actual)

    def test_pil_image(self, p):
        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
        transform = transforms.RandomVerticalFlip(p=p)

        actual = transform(to_pil_image(input))

        assert_equal(expected, pil_to_tensor(actual))

    def test_features_image(self, p):
        input, expected = self.input_expected_image_tensor(p)
        transform = transforms.RandomVerticalFlip(p=p)

        actual = transform(features.Image(input))

        assert_equal(features.Image(expected), actual)

    def test_features_segmentation_mask(self, p):
        input, expected = self.input_expected_image_tensor(p)
        transform = transforms.RandomVerticalFlip(p=p)

        actual = transform(features.SegmentationMask(input))

        assert_equal(features.SegmentationMask(expected), actual)

    def test_features_bounding_box(self, p):
        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
        transform = transforms.RandomVerticalFlip(p=p)

        actual = transform(input)

        expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
        assert_equal(expected, actual)
        assert actual.format == expected.format
        assert actual.image_size == expected.image_size
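For a quick sense of what the tests above exercise, here is a minimal usage sketch on a plain tensor, mirroring test_simple_tensor. The torchvision.prototype.transforms import path is an assumption about the prototype namespace this PR targets; prototype APIs are unstable and may have moved since.

import torch
from torchvision.prototype import transforms  # assumed prototype import path

img = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=torch.float32)

# p=1.0 forces the flip: rows are reversed along the vertical axis.
print(transforms.RandomVerticalFlip(p=1.0)(img))
# tensor([[[0., 0.], [1., 1.]], [[0., 0.], [1., 1.]]])

# p=0.0 never flips: the input comes back unchanged.
print(torch.equal(transforms.RandomVerticalFlip(p=0.0)(img), img))  # True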
@@ -15,6 +15,7 @@ from ._geometry import (
    TenCrop,
    BatchMultiCrop,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    Pad,
    RandomZoomOut,
)
...
@@ -45,6 +45,36 @@ class RandomHorizontalFlip(Transform):
        return input


class RandomVerticalFlip(Transform):
    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        # Draw once per sample: skip the flip with probability 1 - p, so that
        # every component of the sample is flipped (or not) together.
        if torch.rand(1) > self.p:
            return sample

        return super().forward(sample)

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.vertical_flip_image_tensor(input)
            return features.Image.new_like(input, output)
        elif isinstance(input, features.SegmentationMask):
            output = F.vertical_flip_segmentation_mask(input)
            return features.SegmentationMask.new_like(input, output)
        elif isinstance(input, features.BoundingBox):
            output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
            return features.BoundingBox.new_like(input, output)
        elif isinstance(input, PIL.Image.Image):
            return F.vertical_flip_image_pil(input)
        elif is_simple_tensor(input):
            return F.vertical_flip_image_tensor(input)
        else:
            return input


class Resize(Transform):
    def __init__(
        self,
...
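Because forward draws a single random number before dispatching, one flip decision covers the whole sample. A hedged sketch of that joint behavior on an image plus its bounding box; the features wrappers and the multi-input call are assumptions based on the prototype Transform base class shown here, not a documented stable API.

import torch
from torchvision.prototype import features, transforms  # assumed prototype import paths

image = features.Image(torch.rand(3, 10, 10))
box = features.BoundingBox(
    [0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)
)

# One random draw decides the flip for the whole sample, keeping the image
# and its box geometrically consistent; p=1.0 makes the flip deterministic.
flipped_image, flipped_box = transforms.RandomVerticalFlip(p=1.0)(image, box)
print(flipped_box)  # BoundingBox([0, 5, 5, 10], ...), matching the test expectation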
@@ -63,6 +63,8 @@ from ._geometry import (
    perspective_image_pil,
    vertical_flip_image_tensor,
    vertical_flip_image_pil,
    vertical_flip_bounding_box,
    vertical_flip_segmentation_mask,
    five_crop_image_tensor,
    five_crop_image_pil,
    ten_crop_image_tensor,
...
@@ -81,6 +81,26 @@ vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip


def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(segmentation_mask)


def vertical_flip_bounding_box(
    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    # Flip the y coordinates and swap them so y1 <= y2 still holds:
    # (x1, y1, x2, y2) -> (x1, H - y2, x2, H - y1) with H = image_size[0].
    bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]

    return convert_bounding_box_format(
        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
    ).view(shape)


def _affine_parse_args(
    angle: float,
    translate: List[float],
...
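To check the index arithmetic in vertical_flip_bounding_box by hand: in XYXY coordinates a vertical flip maps (x1, y1, x2, y2) to (x1, H - y2, x2, H - y1), where H = image_size[0]; the two y values are swapped so the box stays well-formed. A standalone sketch of just that line, using the same numbers as the bounding-box test above:

import torch

H = 10  # image height, i.e. image_size[0]
box = torch.tensor([[0.0, 0.0, 5.0, 5.0]])  # one XYXY box, as in the test

box[:, [1, 3]] = H - box[:, [3, 1]]
print(box)  # tensor([[ 0.,  5.,  5., 10.]]) -- matches the expected [0, 5, 5, 10]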