Unverified commit 25c8a3a2, authored by Philip Meier, committed by GitHub

port horizontal flip tests (#7703)

parent c9ac3a5b
@@ -406,59 +406,6 @@ def test_simple_tensor_heuristic(flat_inputs):
     assert transform.was_applied(output, input)
 
 
-@pytest.mark.parametrize("p", [0.0, 1.0])
-class TestRandomHorizontalFlip:
-    def input_expected_image_tensor(self, p, dtype=torch.float32):
-        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
-        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
-
-        return input, expected if p == 1 else input
-
-    def test_simple_tensor(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        assert_equal(expected, actual)
-
-    def test_pil_image(self, p):
-        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(to_pil_image(input))
-
-        assert_equal(expected, pil_to_tensor(actual))
-
-    def test_datapoints_image(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Image(input))
-
-        assert_equal(datapoints.Image(expected), actual)
-
-    def test_datapoints_mask(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Mask(input))
-
-        assert_equal(datapoints.Mask(expected), actual)
-
-    def test_datapoints_bounding_box(self, p):
-        input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
-        assert_equal(expected, actual)
-        assert actual.format == expected.format
-        assert actual.spatial_size == expected.spatial_size
-
-
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomVerticalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):
...
@@ -295,9 +295,9 @@ def check_transform(transform_cls, input, *args, **kwargs):
     _check_transform_v1_compatibility(transform, input)
 
 
-def transform_cls_to_functional(transform_cls):
+def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
-        transform = transform_cls(*args, **kwargs)
+        transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
         return transform(input)
 
     wrapper.__name__ = transform_cls.__name__
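The added `transform_specific_kwargs` lets callers pre-bind constructor arguments when turning a transform class into a functional; the ported tests below use it to pin `p=1` so the flip always fires. A minimal usage sketch (`image` stands in for any supported input):

```python
# Pre-bind p=1 so the returned functional applies the flip deterministically.
hflip = transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)
flipped = hflip(image)  # equivalent to transforms.RandomHorizontalFlip(p=1)(image)
```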
@@ -321,14 +321,14 @@ def assert_warns_antialias_default_value():
 
 
 def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
-    def transform(bbox, affine_matrix_, format_, spatial_size_):
+    def transform(bbox):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
             bbox.as_subclass(torch.Tensor),
-            old_format=format_,
+            old_format=format,
             new_format=datapoints.BoundingBoxFormat.XYXY,
             inplace=True,
         )
@@ -340,7 +340,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix_.T)
+        transformed_points = np.matmul(points, affine_matrix.T)
         out_bbox = torch.tensor(
             [
                 np.min(transformed_points[:, 0]).item(),
@@ -351,23 +351,14 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
             dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
         )
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format, spatial_size=spatial_size)
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox
 
-    if bounding_box.ndim < 2:
-        bounding_box = [bounding_box]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape)
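The new one-liner replaces the explicit single-box vs. batch branching: `reshape(-1, 4)` flattens any leading dimensions, `unbind()` yields one box at a time for the closure, and the final `reshape` restores the input shape. A self-contained sketch of that pattern:

```python
import torch

def map_boxes(boxes: torch.Tensor, fn) -> torch.Tensor:
    # Works for a single box of shape (4,) as well as batches of shape (..., 4):
    # reshape(-1, 4) flattens all leading dims, unbind() yields one (4,) tensor
    # per box, and the final reshape restores the original shape.
    return torch.stack([fn(b) for b in boxes.reshape(-1, 4).unbind()]).reshape(boxes.shape)

single = torch.tensor([0.0, 0.0, 5.0, 5.0])
batch = single.expand(3, 4)
assert map_boxes(single, lambda b: b + 1).shape == (4,)
assert map_boxes(batch, lambda b: b + 1).shape == (3, 4)
```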
 class TestResize:
@@ -493,7 +484,7 @@ class TestResize:
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize(
-        "input_type_and_kernel",
+        ("input_type", "kernel"),
         [
             (torch.Tensor, F.resize_image_tensor),
             (PIL.Image.Image, F.resize_image_pil),
@@ -503,8 +494,7 @@ class TestResize:
             (datapoints.Video, F.resize_video),
         ],
     )
-    def test_dispatcher(self, size, input_type_and_kernel):
-        input_type, kernel = input_type_and_kernel
+    def test_dispatcher(self, size, input_type, kernel):
         check_dispatcher(
             F.resize,
             kernel,
@@ -726,3 +716,147 @@ class TestResize:
         output = F.resize(input, size=size, max_size=max_size, antialias=True)
         assert max(F.get_spatial_size(output)) == max_size
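Passing a tuple of parameter names to `pytest.mark.parametrize` makes pytest unpack each entry into separate arguments, which is what removes the manual `input_type, kernel = ...` line deleted above. This is standard pytest behavior, e.g.:

```python
import pytest

@pytest.mark.parametrize(
    ("input_type", "kernel"),
    [
        (int, "int-kernel"),
        (str, "str-kernel"),
    ],
)
def test_example(input_type, kernel):
    # pytest unpacks each 2-tuple into the two named arguments.
    assert isinstance(kernel, str)
```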
+class TestHorizontalFlip:
+    def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
+        if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
+            input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+            if input_type is torch.Tensor:
+                input = input.as_subclass(torch.Tensor)
+            elif input_type is PIL.Image.Image:
+                input = F.to_image_pil(input)
+        elif input_type is datapoints.BoundingBox:
+            kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
+            input = make_bounding_box(
+                dtype=dtype or torch.float32,
+                device=device,
+                spatial_size=spatial_size,
+                **kwargs,
+            )
+        elif input_type is datapoints.Mask:
+            input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+        elif input_type is datapoints.Video:
+            input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+
+        return input
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, dtype, device):
check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box(self, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
check_kernel(
F.horizontal_flip_bounding_box,
bounding_box,
format=format,
spatial_size=bounding_box.spatial_size,
)
@pytest.mark.parametrize(
"dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
)
def test_kernel_mask(self, dtype_and_make_mask):
dtype, make_mask = dtype_and_make_mask
check_kernel(F.horizontal_flip_mask, make_mask(dtype=dtype))
def test_kernel_video(self):
check_kernel(F.horizontal_flip_video, self._make_input(datapoints.Video))
@pytest.mark.parametrize(
("input_type", "kernel"),
[
(torch.Tensor, F.horizontal_flip_image_tensor),
(PIL.Image.Image, F.horizontal_flip_image_pil),
(datapoints.Image, F.horizontal_flip_image_tensor),
(datapoints.BoundingBox, F.horizontal_flip_bounding_box),
(datapoints.Mask, F.horizontal_flip_mask),
(datapoints.Video, F.horizontal_flip_video),
],
)
def test_dispatcher(self, kernel, input_type):
check_dispatcher(F.horizontal_flip, kernel, self._make_input(input_type))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        check_transform(transforms.RandomHorizontalFlip, input, p=1)
+
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_image_correctness(self, fn):
+        image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
+
+        actual = fn(image)
+        expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+
+        torch.testing.assert_close(actual, expected)
+
+    def _reference_horizontal_flip_bounding_box(self, bounding_box):
+        affine_matrix = np.array(
+            [
+                [-1, 0, bounding_box.spatial_size[1]],
+                [0, 1, 0],
+            ],
+            dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        )
+
+        expected_bboxes = reference_affine_bounding_box_helper(
+            bounding_box,
+            format=bounding_box.format,
+            spatial_size=bounding_box.spatial_size,
+            affine_matrix=affine_matrix,
+        )
+
+        return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_bounding_box_correctness(self, format, fn):
bounding_box = self._make_input(datapoints.BoundingBox)
actual = fn(bounding_box)
expected = self._reference_horizontal_flip_bounding_box(bounding_box)
torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_noop(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        transform = transforms.RandomHorizontalFlip(p=0)
+
+        output = transform(input)
+
+        assert_equal(output, input)
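For reference, the affine matrix in `_reference_horizontal_flip_bounding_box` encodes x' = W - x, y' = y with W = `spatial_size[1]`. A quick numpy check of the corner-transform scheme used by `reference_affine_bounding_box_helper`, reproducing the `[0, 0, 5, 5] -> [5, 0, 10, 5]` case asserted in the deleted `TestRandomHorizontalFlip`:

```python
import numpy as np

W = 10  # image width, i.e. spatial_size[1]
x1, y1, x2, y2 = 0, 0, 5, 5  # XYXY box

flip = np.array([[-1, 0, W], [0, 1, 0]], dtype="float32")
corners = np.array([[x1, y1, 1], [x2, y1, 1], [x1, y2, 1], [x2, y2, 1]], dtype="float32")
transformed = corners @ flip.T  # shape (4, 2): flipped corner coordinates

# Re-assemble the axis-aligned box from the transformed corners.
out = [transformed[:, 0].min(), transformed[:, 1].min(), transformed[:, 0].max(), transformed[:, 1].max()]
assert out == [5, 0, 10, 5]
```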
@@ -138,16 +138,6 @@ xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
 DISPATCHER_INFOS = [
-    DispatcherInfo(
-        F.horizontal_flip,
-        kernels={
-            datapoints.Image: F.horizontal_flip_image_tensor,
-            datapoints.Video: F.horizontal_flip_video,
-            datapoints.BoundingBox: F.horizontal_flip_bounding_box,
-            datapoints.Mask: F.horizontal_flip_mask,
-        },
-        pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
-    ),
     DispatcherInfo(
         F.affine,
         kernels={
...
@@ -156,88 +156,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
 KERNEL_INFOS = []
 
 
-def sample_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
-        yield ArgsKwargs(image_loader)
-
-
-def reference_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_horizontal_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
-    ):
-        yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
-        )
-
-
-def sample_inputs_horizontal_flip_mask():
-    for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_horizontal_flip_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
-        yield ArgsKwargs(video_loader)
-
-
-def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size):
-    affine_matrix = np.array(
-        [
-            [-1, 0, spatial_size[1]],
-            [0, 1, 0],
-        ],
-        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
-    )
-
-    expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
-    )
-
-    return expected_bboxes
-
-
-def reference_inputs_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
-        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
-        )
-
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.horizontal_flip_image_tensor,
-            kernel_name="horizontal_flip_image_tensor",
-            sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
-            reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
-            reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
-            float32_vs_uint8=True,
-        ),
-        KernelInfo(
-            F.horizontal_flip_bounding_box,
-            sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
-            reference_fn=reference_horizontal_flip_bounding_box,
-            reference_inputs_fn=reference_inputs_flip_bounding_box,
-        ),
-        KernelInfo(
-            F.horizontal_flip_mask,
-            sample_inputs_fn=sample_inputs_horizontal_flip_mask,
-        ),
-        KernelInfo(
-            F.horizontal_flip_video,
-            sample_inputs_fn=sample_inputs_horizontal_flip_video,
-        ),
-    ]
-)
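The removed sample-input functions all follow the same generator protocol: yield one `ArgsKwargs` per kernel invocation. A minimal sketch of how such generators are consumed; `ArgsKwargs` here is a stand-in container, not the actual class from the test utilities:

```python
class ArgsKwargs:
    # Stand-in container; the real test utility may differ.
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

def sample_inputs_scale():
    for value in [1, 2, 3]:
        yield ArgsKwargs(value, scale=2)

def run_samples(kernel, sample_inputs_fn):
    # Each yielded sample describes exactly one kernel call.
    for sample in sample_inputs_fn():
        kernel(*sample.args, **sample.kwargs)

run_samples(lambda x, *, scale: x * scale, sample_inputs_scale)
```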
 _AFFINE_KWARGS = combinations_grid(
     angle=[-87, 15, 90],
     translate=[(5, 5), (-5, -5)],
@@ -573,6 +491,15 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
     return expected_bboxes
 
 
+def reference_inputs_vertical_flip_bounding_box():
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+        )
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -587,7 +514,7 @@ KERNEL_INFOS.extend(
             F.vertical_flip_bounding_box,
             sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
             reference_fn=reference_vertical_flip_bounding_box,
-            reference_inputs_fn=reference_inputs_flip_bounding_box,
+            reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
         ),
         KernelInfo(
             F.vertical_flip_mask,
...
@@ -43,7 +43,8 @@ def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)
 
 
-horizontal_flip_image_pil = _FP.hflip
+def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
+    return _FP.hflip(image)
 
 
 def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
...
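Replacing the bare alias with a `def` presumably gives the PIL kernel its own public name and parameter name (`image` rather than the v1 helper's), which matters for signature-consistency checks like `check_dispatcher_signatures_match` above; this reasoning is inferred, not stated in the commit. The observable difference, using a stand-in for `_FP.hflip`:

```python
import inspect

def _v1_hflip(img):  # stand-in for the v1 helper the alias pointed to
    return img

alias = _v1_hflip
print(alias.__name__, inspect.signature(alias))
# -> _v1_hflip (img)

def horizontal_flip_image_pil(image):
    return _v1_hflip(image)

print(horizontal_flip_image_pil.__name__, inspect.signature(horizontal_flip_image_pil))
# -> horizontal_flip_image_pil (image)
```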