Unverified Commit 41035520 authored by Philip Meier, committed by GitHub

port convert_bounding_box_format tests (#7933)

parent 1f94320d
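
As context for this port, here is a minimal sketch of the functional under test, assuming the torchvision v2 API (tv_tensors.BoundingBoxes plus torchvision.transforms.v2.functional.convert_bounding_box_format). It is an illustration only, not part of the commit:

# Sketch only: basic conversion with the v2 API (torchvision >= 0.16 assumed).
import torch
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

# A BoundingBoxes tv_tensor carries its format and canvas size as metadata.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[2.0, 2.0, 8.0, 8.0]]), format="XYXY", canvas_size=(10, 10)
)

# The old format is read from the tv_tensor metadata; only the target format is passed.
out = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
print(out)  # BoundingBoxes([[5., 5., 6., 6.]]) with format=CXCYWH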
......@@ -522,33 +522,6 @@ class TestClampBoundingBoxes:
            F.clamp_bounding_boxes(tv_tensor, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_multiple_bounding_boxes()), None),
            (next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_bounding_box_format(inpt, old_format)

    def test_pure_tensor_insufficient_metadata(self):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)

    def test_tv_tensor_explicit_metadata(self):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_bounding_box_format(
                tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `transforms_v2_kernel_infos.py`
......
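
The removed tests above pin down the metadata rules that the new TestConvertBoundingBoxFormat.test_errors below keeps checking: a pure tensor needs an explicit old_format, while a BoundingBoxes tv_tensor must not receive one. A rough sketch under the same API assumptions, with the error-raising calls left commented so it runs as-is:

# Sketch only: metadata rules exercised by the removed tests and by test_errors below.
import torch
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(torch.tensor([[0, 0, 4, 4]]), format="XYXY", canvas_size=(8, 8))
pure = boxes.as_subclass(torch.Tensor)  # plain tensor, format metadata is gone

# OK: a pure tensor works if its current format is supplied explicitly.
F.convert_bounding_box_format(
    pure, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)
# ValueError "`old_format` has to be passed": pure tensor without old_format.
# F.convert_bounding_box_format(pure, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
# ValueError "`old_format` must not be passed": the tv_tensor already knows its format.
# F.convert_bounding_box_format(boxes, old_format=boxes.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)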
import contextlib
import decimal
import functools
import inspect
import itertools
import math
import pickle
import re
......@@ -12,6 +14,8 @@ import PIL.Image
import pytest
import torch
import torchvision.ops
import torchvision.transforms.v2 as transforms
from common_utils import (
    assert_equal,
......@@ -138,7 +142,6 @@ def check_kernel(
    check_cuda_vs_cpu=True,
    check_scripted_vs_eager=True,
    check_batched_vs_unbatched=True,
    expect_same_dtype=True,
    **kwargs,
):
    initial_input_version = input._version
......@@ -151,7 +154,7 @@ def check_kernel(
    # check that no inplace operation happened
    assert input._version == initial_input_version

    if expect_same_dtype:
    if kernel not in {F.to_dtype_image, F.to_dtype_video}:
        assert output.dtype == input.dtype
    assert output.device == input.device
......@@ -187,7 +190,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
    assert isinstance(output, type(input))

    if isinstance(input, tv_tensors.BoundingBoxes):
    if isinstance(input, tv_tensors.BoundingBoxes) and functional is not F.convert_bounding_box_format:
        assert output.format == input.format

    if check_scripted_smoke:
......@@ -264,7 +267,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
    output = transform(input)
    assert isinstance(output, type(input))

    if isinstance(input, tv_tensors.BoundingBoxes):
    if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
        assert output.format == input.format

    if check_v1_compatibility:
......@@ -1743,7 +1746,6 @@ class TestToDtype:
        check_kernel(
            kernel,
            make_input(dtype=input_dtype, device=device),
            expect_same_dtype=input_dtype is output_dtype,
            dtype=output_dtype,
            scale=scale,
        )
......@@ -3009,3 +3011,102 @@ class TestAutoAugmentTransforms:
    def test_aug_mix_severity_error(self, severity):
        with pytest.raises(ValueError, match="severity must be between"):
            transforms.AugMix(severity=severity)


class TestConvertBoundingBoxFormat:
    old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_kernel(self, old_format, new_format):
        check_kernel(
            F.convert_bounding_box_format,
            make_bounding_boxes(format=old_format),
            new_format=new_format,
            old_format=old_format,
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("inplace", [False, True])
    def test_kernel_noop(self, format, inplace):
        input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
        input_version = input._version

        output = F.convert_bounding_box_format(input, old_format=format, new_format=format, inplace=inplace)

        assert output is input
        assert output.data_ptr() == input.data_ptr()
        assert output._version == input_version

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_kernel_inplace(self, old_format, new_format):
        input = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor)
        input_version = input._version

        output_out_of_place = F.convert_bounding_box_format(input, old_format=old_format, new_format=new_format)
        assert output_out_of_place.data_ptr() != input.data_ptr()
        assert output_out_of_place is not input

        output_inplace = F.convert_bounding_box_format(
            input, old_format=old_format, new_format=new_format, inplace=True
        )
        assert output_inplace.data_ptr() == input.data_ptr()
        assert output_inplace._version > input_version
        assert output_inplace is input

        assert_equal(output_inplace, output_out_of_place)

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_functional(self, old_format, new_format):
        check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format)

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    @pytest.mark.parametrize("format_type", ["enum", "str"])
    def test_transform(self, old_format, new_format, format_type):
        check_transform(
            transforms.ConvertBoundingBoxFormat(new_format.name if format_type == "str" else new_format),
            make_bounding_boxes(format=old_format),
        )

    def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
        return tv_tensors.wrap(
            torchvision.ops.box_convert(
                bounding_boxes.as_subclass(torch.Tensor),
                in_fmt=bounding_boxes.format.name.lower(),
                out_fmt=new_format.name.lower(),
            ).to(bounding_boxes.dtype),
            like=bounding_boxes,
            format=new_format,
        )

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("fn_type", ["functional", "transform"])
    def test_correctness(self, old_format, new_format, dtype, device, fn_type):
        bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)

        if fn_type == "functional":
            fn = functools.partial(F.convert_bounding_box_format, new_format=new_format)
        else:
            fn = transforms.ConvertBoundingBoxFormat(format=new_format)

        actual = fn(bounding_boxes)
        expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format)

        assert_equal(actual, expected)

    def test_errors(self):
        input_tv_tensor = make_bounding_boxes()
        input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)

        for input in [input_tv_tensor, input_pure_tensor]:
            with pytest.raises(TypeError, match="missing 1 required argument: 'new_format'"):
                F.convert_bounding_box_format(input)

        with pytest.raises(ValueError, match="`old_format` has to be passed"):
            F.convert_bounding_box_format(input_pure_tensor, new_format=input_tv_tensor.format)

        with pytest.raises(ValueError, match="`old_format` must not be passed"):
            F.convert_bounding_box_format(
                input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format
            )
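
Usage note on the in-place behaviour that test_kernel_noop and test_kernel_inplace assert above; this is a sketch under the same API assumptions, not test code:

# Sketch only: in-place semantics of the pure-tensor kernel path asserted above.
import torch
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

XYXY = tv_tensors.BoundingBoxFormat.XYXY
CXCYWH = tv_tensors.BoundingBoxFormat.CXCYWH
t = torch.tensor([[2.0, 2.0, 8.0, 8.0]])

noop = F.convert_bounding_box_format(t, old_format=XYXY, new_format=XYXY)  # same format: input returned as-is
out = F.convert_bounding_box_format(t, old_format=XYXY, new_format=CXCYWH)  # default: new tensor, t untouched
inp = F.convert_bounding_box_format(t, old_format=XYXY, new_format=CXCYWH, inplace=True)  # mutates t

assert noop is t
assert out is not t
assert inp is t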
......@@ -315,11 +315,4 @@ DISPATCHER_INFOS = [
            skip_dispatch_tv_tensor,
        ],
    ),
    DispatcherInfo(
        F.convert_bounding_box_format,
        kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
]
......@@ -5,7 +5,6 @@ import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
......@@ -227,38 +226,6 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
    ).reshape(bounding_boxes.shape)


def sample_inputs_convert_bounding_box_format():
    formats = list(tv_tensors.BoundingBoxFormat)
    for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)


def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
    return torchvision.ops.box_convert(
        bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    ).to(bounding_boxes.dtype)


def reference_inputs_convert_bounding_box_format():
    for args_kwargs in sample_inputs_convert_bounding_box_format():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs


KERNEL_INFOS.append(
    KernelInfo(
        F.convert_bounding_box_format,
        sample_inputs_fn=sample_inputs_convert_bounding_box_format,
        reference_fn=reference_convert_bounding_box_format,
        reference_inputs_fn=reference_inputs_convert_bounding_box_format,
        logs_usage=True,
        closeness_kwargs={
            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
        },
    ),
)
_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])
......
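
Both the removed KernelInfo reference above and the new _reference_convert_bounding_box_format rely on torchvision.ops.box_convert, which takes lower-cased format names. A minimal sketch of that cross-check on plain tensors (illustrative only):

# Sketch only: the cross-check both reference helpers perform. On plain tensors the v2
# kernel should agree with torchvision.ops.box_convert (lower-cased format names).
import torch
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

boxes = torch.tensor([[2.0, 2.0, 8.0, 8.0], [0.0, 1.0, 4.0, 7.0]])
actual = F.convert_bounding_box_format(
    boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)
expected = torchvision.ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
torch.testing.assert_close(actual, expected)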