Unverified Commit 71968bc4 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port vertical flip (#7712)

parent 0e496155
...@@ -29,7 +29,7 @@ from common_utils import ( ...@@ -29,7 +29,7 @@ from common_utils import (
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints from torchvision import datapoints
from torchvision.ops.boxes import box_iou from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
...@@ -406,59 +406,6 @@ def test_simple_tensor_heuristic(flat_inputs): ...@@ -406,59 +406,6 @@ def test_simple_tensor_heuristic(flat_inputs):
assert transform.was_applied(output, input) assert transform.was_applied(output, input)
@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomVerticalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)
return input, expected if p == 1 else input
def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(input)
assert_equal(expected, actual)
def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(to_pil_image(input))
assert_equal(expected, pil_to_tensor(actual))
def test_datapoints_image(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(datapoints.Image(input))
assert_equal(datapoints.Image(expected), actual)
def test_datapoints_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(datapoints.Mask(input))
assert_equal(datapoints.Mask(expected), actual)
def test_datapoints_bounding_box(self, p):
input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(input)
expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.spatial_size == expected.spatial_size
class TestPad: class TestPad:
def test_assertions(self): def test_assertions(self):
with pytest.raises(TypeError, match="Got inappropriate padding arg"): with pytest.raises(TypeError, match="Got inappropriate padding arg"):
......
...@@ -842,7 +842,7 @@ class TestHorizontalFlip: ...@@ -842,7 +842,7 @@ class TestHorizontalFlip:
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
) )
def test_bounding_box_correctness(self, format, fn): def test_bounding_box_correctness(self, format, fn):
bounding_box = self._make_input(datapoints.BoundingBox) bounding_box = self._make_input(datapoints.BoundingBox, format=format)
actual = fn(bounding_box) actual = fn(bounding_box)
expected = self._reference_horizontal_flip_bounding_box(bounding_box) expected = self._reference_horizontal_flip_bounding_box(bounding_box)
...@@ -1025,12 +1025,10 @@ class TestAffine: ...@@ -1025,12 +1025,10 @@ class TestAffine:
@pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) @pytest.mark.parametrize("mask_type", ["segmentation", "detection"])
def test_kernel_mask(self, mask_type): def test_kernel_mask(self, mask_type):
check_kernel( self._check_kernel(F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type))
F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type), **self._MINIMAL_AFFINE_KWARGS
)
def test_kernel_video(self): def test_kernel_video(self):
check_kernel(F.affine_video, self._make_input(datapoints.Video), **self._MINIMAL_AFFINE_KWARGS) self._check_kernel(F.affine_video, self._make_input(datapoints.Video))
@pytest.mark.parametrize( @pytest.mark.parametrize(
("input_type", "kernel"), ("input_type", "kernel"),
...@@ -1301,3 +1299,143 @@ class TestAffine: ...@@ -1301,3 +1299,143 @@ class TestAffine:
def test_transform_unknown_fill_error(self): def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"): with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill") transforms.RandomAffine(degrees=0, fill="fill")
class TestVerticalFlip:
def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
if input_type is torch.Tensor:
input = input.as_subclass(torch.Tensor)
elif input_type is PIL.Image.Image:
input = F.to_image_pil(input)
elif input_type is datapoints.BoundingBox:
kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
input = make_bounding_box(
dtype=dtype or torch.float32,
device=device,
spatial_size=spatial_size,
**kwargs,
)
elif input_type is datapoints.Mask:
input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
elif input_type is datapoints.Video:
input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
return input
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, dtype, device):
check_kernel(F.vertical_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box(self, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
check_kernel(
F.vertical_flip_bounding_box,
bounding_box,
format=format,
spatial_size=bounding_box.spatial_size,
)
@pytest.mark.parametrize(
"dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
)
def test_kernel_mask(self, dtype_and_make_mask):
dtype, make_mask = dtype_and_make_mask
check_kernel(F.vertical_flip_mask, make_mask(dtype=dtype))
def test_kernel_video(self):
check_kernel(F.vertical_flip_video, self._make_input(datapoints.Video))
@pytest.mark.parametrize(
("input_type", "kernel"),
[
(torch.Tensor, F.vertical_flip_image_tensor),
(PIL.Image.Image, F.vertical_flip_image_pil),
(datapoints.Image, F.vertical_flip_image_tensor),
(datapoints.BoundingBox, F.vertical_flip_bounding_box),
(datapoints.Mask, F.vertical_flip_mask),
(datapoints.Video, F.vertical_flip_video),
],
)
def test_dispatcher(self, kernel, input_type):
check_dispatcher(F.vertical_flip, kernel, self._make_input(input_type))
@pytest.mark.parametrize(
("input_type", "kernel"),
[
(torch.Tensor, F.vertical_flip_image_tensor),
(PIL.Image.Image, F.vertical_flip_image_pil),
(datapoints.Image, F.vertical_flip_image_tensor),
(datapoints.BoundingBox, F.vertical_flip_bounding_box),
(datapoints.Mask, F.vertical_flip_mask),
(datapoints.Video, F.vertical_flip_video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"input_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, input_type, device):
input = self._make_input(input_type, device=device)
check_transform(transforms.RandomVerticalFlip, input, p=1)
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
actual = fn(image)
expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))
torch.testing.assert_close(actual, expected)
def _reference_vertical_flip_bounding_box(self, bounding_box):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, bounding_box.spatial_size[0]],
],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
)
expected_bboxes = reference_affine_bounding_box_helper(
bounding_box,
format=bounding_box.format,
spatial_size=bounding_box.spatial_size,
affine_matrix=affine_matrix,
)
return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_bounding_box_correctness(self, format, fn):
bounding_box = self._make_input(datapoints.BoundingBox, format=format)
actual = fn(bounding_box)
expected = self._reference_vertical_flip_bounding_box(bounding_box)
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize(
"input_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, input_type, device):
input = self._make_input(input_type, device=device)
transform = transforms.RandomVerticalFlip(p=0)
output = transform(input)
assert_equal(output, input)
...@@ -138,16 +138,6 @@ xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil( ...@@ -138,16 +138,6 @@ xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
DISPATCHER_INFOS = [ DISPATCHER_INFOS = [
DispatcherInfo(
F.vertical_flip,
kernels={
datapoints.Image: F.vertical_flip_image_tensor,
datapoints.Video: F.vertical_flip_video,
datapoints.BoundingBox: F.vertical_flip_bounding_box,
datapoints.Mask: F.vertical_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
),
DispatcherInfo( DispatcherInfo(
F.rotate, F.rotate,
kernels={ kernels={
......
...@@ -264,87 +264,6 @@ KERNEL_INFOS.append( ...@@ -264,87 +264,6 @@ KERNEL_INFOS.append(
) )
def sample_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
yield ArgsKwargs(image_loader)
def reference_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
def sample_inputs_vertical_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(
formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
):
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)
def sample_inputs_vertical_flip_mask():
for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
def sample_inputs_vertical_flip_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader)
def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, spatial_size[0]],
],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
)
expected_bboxes = reference_affine_bounding_box_helper(
bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
)
return expected_bboxes
def reference_inputs_vertical_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)
KERNEL_INFOS.extend(
[
KernelInfo(
F.vertical_flip_image_tensor,
kernel_name="vertical_flip_image_tensor",
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
float32_vs_uint8=True,
),
KernelInfo(
F.vertical_flip_bounding_box,
sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
reference_fn=reference_vertical_flip_bounding_box,
reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
),
KernelInfo(
F.vertical_flip_mask,
sample_inputs_fn=sample_inputs_vertical_flip_mask,
),
KernelInfo(
F.vertical_flip_video,
sample_inputs_fn=sample_inputs_vertical_flip_video,
),
]
)
_ROTATE_ANGLES = [-87, 15, 90] _ROTATE_ANGLES = [-87, 15, 90]
......
...@@ -93,7 +93,8 @@ def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -93,7 +93,8 @@ def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2) return image.flip(-2)
vertical_flip_image_pil = _FP.vflip def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
return _FP.vflip(image)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment