Unverified commit a63046ce, authored by Philip Meier, committed by GitHub

only use plain tensors in kernel tests (#7230)

parent acabaf80
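
In short: the kernel test suites below stop relying on loader-side `unwrap()` helpers and instead convert datapoints to plain tensors right before calling a kernel, via `Tensor.as_subclass`. A minimal sketch of that round trip (the toy `ToyBoundingBox` subclass is illustrative, not the torchvision class):

import torch

class ToyBoundingBox(torch.Tensor):  # stand-in for datapoints.BoundingBox
    pass

box = torch.tensor([[10.0, 15.0, 25.0, 35.0]]).as_subclass(ToyBoundingBox)
plain = box.as_subclass(torch.Tensor)  # what the kernel tests now pass to kernels

assert type(box) is ToyBoundingBox
assert type(plain) is torch.Tensor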
@@ -237,13 +237,6 @@ class TensorLoader:
     def load(self, device):
         return self.fn(self.shape, self.dtype, device)
 
-    def unwrap(self):
-        return TensorLoader(
-            fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
-            shape=self.shape,
-            dtype=self.dtype,
-        )
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
...
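
For context, `TensorLoader` (whose `unwrap` method is deleted above) is a small factory wrapper; a self-contained sketch based on the fields visible in this hunk:

import dataclasses
import torch

@dataclasses.dataclass
class TensorLoaderSketch:  # reduced model of the test helper above
    fn: object  # callable: (shape, dtype, device) -> torch.Tensor
    shape: tuple
    dtype: torch.dtype

    def load(self, device):
        return self.fn(self.shape, self.dtype, device)

loader = TensorLoaderSketch(
    fn=lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
    shape=(3, 32, 32),
    dtype=torch.float32,
)
print(loader.load("cpu").shape)  # torch.Size([3, 32, 32])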
@@ -64,13 +64,21 @@ class DispatcherInfo(InfoBase):
         if not filter_metadata:
             yield from sample_inputs
-        else:
-            for args_kwargs in sample_inputs:
-                for attribute in datapoint_type.__annotations__.keys():
-                    if attribute in args_kwargs.kwargs:
-                        del args_kwargs.kwargs[attribute]
-                yield args_kwargs
+            return
+
+        import itertools
+
+        for args_kwargs in sample_inputs:
+            for name in itertools.chain(
+                datapoint_type.__annotations__.keys(),
+                # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                # per-dispatcher level. However, so far there is no option for that.
+                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+            ):
+                if name in args_kwargs.kwargs:
+                    del args_kwargs.kwargs[name]
+            yield args_kwargs
 
 
 def xfail_jit(reason, *, condition=None):
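
The rewritten filter now also drops `old_{name}` variants of the metadata keys (see the FIXME). The same idea in isolation, with made-up annotations and kwargs:

import itertools

annotations = {"format": str, "spatial_size": tuple}  # stand-in for datapoint_type.__annotations__
kwargs = {"old_format": "XYWH", "new_format": "XYXY", "antialias": True}

for name in itertools.chain(annotations, (f"old_{name}" for name in annotations)):
    kwargs.pop(name, None)  # metadata keys are stripped; everything else survives

print(kwargs)  # {'new_format': 'XYXY', 'antialias': True}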
@@ -458,4 +466,18 @@ DISPATCHER_INFOS = [
             skip_dispatch_datapoint,
         ],
     ),
+    DispatcherInfo(
+        F.clamp_bounding_box,
+        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
+    DispatcherInfo(
+        F.convert_format_bounding_box,
+        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
 ]
@@ -12,7 +12,6 @@ import torchvision.prototype.transforms.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
-    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -337,7 +336,6 @@ def sample_inputs_resize_video():
 def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     old_height, old_width = spatial_size
     new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)
@@ -350,13 +348,15 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     )
 
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
+        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
     )
 
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
+    ):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
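
Pinning the reference loaders to XYXY (above) works because resizing a box is a pure per-axis scaling in that format; the math the affine helper encodes, as a standalone sketch:

import torch

def resize_xyxy(boxes, old_size, new_size):
    # scale x coordinates by new_width / old_width and y by new_height / old_height
    (old_h, old_w), (new_h, new_w) = old_size, new_size
    scale = torch.tensor([new_w / old_w, new_h / old_h, new_w / old_w, new_h / old_h])
    return boxes * scale

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
print(resize_xyxy(boxes, old_size=(100, 100), new_size=(50, 200)))
# tensor([[20., 10., 60., 20.]])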
@@ -668,8 +668,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_convert_format_bounding_box():
     formats = list(datapoints.BoundingBoxFormat)
     for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
-        yield ArgsKwargs(bounding_box_loader, new_format=new_format)
-        yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
 
 
 def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
@@ -680,14 +679,8 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
 
 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_format_bounding_box():
-        if len(args_kwargs.args[0].shape) != 2:
-            continue
-
-        (loader, *other_args), kwargs = args_kwargs
-        if isinstance(loader, BoundingBoxLoader):
-            kwargs["old_format"] = loader.format
-            loader = loader.unwrap()
-        yield ArgsKwargs(loader, *other_args, **kwargs)
+        if len(args_kwargs.args[0].shape) == 2:
+            yield args_kwargs
 
 
 KERNEL_INFOS.append(
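
With `old_format` now always passed explicitly, conversion is plain coordinate arithmetic. For reference, the XYWH/XYXY case as a standalone sketch (not the torchvision implementation):

import torch

def xywh_to_xyxy(boxes):
    x, y, w, h = boxes.unbind(-1)
    return torch.stack([x, y, x + w, y + h], dim=-1)

def xyxy_to_xywh(boxes):
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack([x1, y1, x2 - x1, y2 - y1], dim=-1)

b = torch.tensor([[10.0, 15.0, 20.0, 30.0]])  # XYWH
assert torch.equal(xyxy_to_xywh(xywh_to_xyxy(b)), b)  # round trip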
@@ -697,18 +690,6 @@ KERNEL_INFOS.append(
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("old_format") is None,
-            )
-        ],
     ),
 )
@@ -2049,10 +2030,8 @@ KERNEL_INFOS.extend(
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
-        yield ArgsKwargs(bounding_box_loader)
         yield ArgsKwargs(
-            bounding_box_loader.unwrap(),
+            bounding_box_loader,
             format=bounding_box_loader.format,
             spatial_size=bounding_box_loader.spatial_size,
         )
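
`format` and `spatial_size` are now always passed to `clamp_bounding_box` (above), since a plain tensor no longer carries them. The operation itself, sketched for XYXY boxes:

import torch

def clamp_xyxy(boxes, spatial_size):
    height, width = spatial_size
    boxes = boxes.clone()
    boxes[..., 0::2].clamp_(min=0, max=width)   # x1, x2
    boxes[..., 1::2].clamp_(min=0, max=height)  # y1, y2
    return boxes

print(clamp_xyxy(torch.tensor([[-5.0, 10.0, 70.0, 90.0]]), spatial_size=(64, 64)))
# tensor([[ 0., 10., 64., 64.]])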
@@ -2063,19 +2042,6 @@ KERNEL_INFOS.append(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("format") is None
-                and arg_kwargs.kwargs.get("spatial_size") is None,
-            )
-        ],
     )
 )
...
@@ -121,8 +121,8 @@ class TestKernels:
     def test_logging(self, spy_on, info, args_kwargs, device):
         spy = spy_on(torch._C._log_api_usage_once)
 
-        args, kwargs = args_kwargs.load(device)
-        info.kernel(*args, **kwargs)
+        (input, *other_args), kwargs = args_kwargs.load(device)
+        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
 
         spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
@@ -134,6 +134,7 @@ class TestKernels:
         kernel_scripted = script(kernel_eager)
 
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         actual = kernel_scripted(input, *other_args, **kwargs)
         expected = kernel_eager(input, *other_args, **kwargs)
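
The new `as_subclass(torch.Tensor)` line matters for this test in particular: TorchScript compiles kernels against plain `torch.Tensor`, and, per the removed `mark_framework_limitation` reasons above, JIT unwraps a `datapoints.BoundingBox` into a `torch.Tensor` anyway. A toy demonstration of that unwrapping (the subclass here is made up):

import torch

@torch.jit.script
def double(t: torch.Tensor) -> torch.Tensor:
    return t * 2

class ToyBoundingBox(torch.Tensor):
    pass

box = torch.ones(4).as_subclass(ToyBoundingBox)
out = double(box)
print(type(out))  # <class 'torch.Tensor'> -- the subclass does not survive scripting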
@@ -155,14 +156,12 @@ class TestKernels:
         if batched_tensor.ndim == data_dims:
             return batch
 
-        unbatcheds = []
-        for unbatched in (
-            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-        ):
-            if isinstance(batch, datapoints._datapoint.Datapoint):
-                unbatched = type(batch).wrap_like(batch, unbatched)
-            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
-        return unbatcheds
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
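
The list-comprehension `_unbatch` above no longer needs the datapoint re-wrapping branch; its recursion in isolation:

import torch

def unbatch(batched_tensor, data_dims):
    # peel off leading batch dims one at a time until only the data dims remain
    if batched_tensor.ndim == data_dims:
        return batched_tensor
    return [unbatch(t, data_dims) for t in batched_tensor.unbind(0)]

out = unbatch(torch.rand(2, 3, 4, 5), data_dims=2)
print(len(out), len(out[0]), out[0][0].shape)  # 2 3 torch.Size([4, 5])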
@@ -195,6 +194,7 @@ class TestKernels:
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")
 
+        batched_input = batched_input.as_subclass(torch.Tensor)
         batched_output = info.kernel(batched_input, *other_args, **kwargs)
         actual = self._unbatch(batched_output, data_dims=data_dims)
@@ -212,6 +212,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_no_inplace(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         if input.numel() == 0:
             pytest.skip("The input has a degenerate shape.")
@@ -225,6 +226,7 @@ class TestKernels:
     @needs_cuda
     def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
         (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
+        input_cpu = input_cpu.as_subclass(torch.Tensor)
         input_cuda = input_cpu.to("cuda")
 
         output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
@@ -242,6 +244,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         output = info.kernel(input, *other_args, **kwargs)
         # Most kernels just return a tensor, but some also return some additional metadata
@@ -254,6 +257,7 @@ class TestKernels:
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         actual = info.kernel(input, *other_args, **kwargs)
         expected = info.reference_fn(input, *other_args, **kwargs)
@@ -271,6 +275,7 @@ class TestKernels:
     )
     def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         if input.dtype != torch.uint8:
             pytest.skip(f"Input dtype is {input.dtype}.")
@@ -341,7 +346,6 @@ class TestDispatchers:
     @pytest.mark.parametrize(
         "dispatcher",
         [
-            F.clamp_bounding_box,
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
@@ -647,21 +651,15 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_bounding_box_on_fixed_input(device):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
-    # xyxy format
     in_boxes = [
         [20, 25, 35, 45],
         [50, 5, 70, 22],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
         [1, 1, 5, 5],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 63
     scale = 0.89
@@ -686,11 +684,11 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
 
     output_boxes = F.affine_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
-        (dx * spatial_size[1], dy * spatial_size[0]),
-        scale,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=(dx * spatial_size[1], dy * spatial_size[0]),
+        scale=scale,
         shear=(0, 0),
     )
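
What the reference side of this test computes: each box is expanded to its four corners, the corners are pushed through the 2x3 affine matrix, and the result is the enclosing axis-aligned box. Sketch:

import numpy as np

def affine_xyxy(box, affine_matrix):  # box: [x1, y1, x2, y2]; affine_matrix: 2x3
    x1, y1, x2, y2 = box
    corners = np.array([[x1, y1, 1.0], [x2, y1, 1.0], [x2, y2, 1.0], [x1, y2, 1.0]])
    transformed = corners @ affine_matrix.T  # (4, 2) transformed corner coordinates
    return transformed.min(axis=0).tolist() + transformed.max(axis=0).tolist()

translation = np.array([[1.0, 0.0, 5.0], [0.0, 1.0, -3.0]])  # shift by (+5, -3)
print(affine_xyxy([10.0, 20.0, 30.0, 40.0], translation))  # [15.0, 17.0, 35.0, 37.0]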
@@ -725,9 +723,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         affine_matrix = affine_matrix[:2, :]
 
         height, width = bbox.spatial_size
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -766,10 +762,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return (
-            convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
-            (height, width),
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
 
     spatial_size = (32, 38)
@@ -778,8 +771,8 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
     bboxes_spatial_size = bboxes.spatial_size
 
     output_bboxes, output_spatial_size = F.rotate_bounding_box(
-        bboxes,
-        bboxes_format,
+        bboxes.as_subclass(torch.Tensor),
+        format=bboxes_format,
         spatial_size=bboxes_spatial_size,
         angle=angle,
         expand=expand,
@@ -810,6 +803,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
 @pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
     # xyxy format
     in_boxes = [
@@ -818,13 +812,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 45
     center = None if expand else [12, 23]
@@ -854,9 +842,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
         expand=expand,
         center=center,
     )
@@ -906,16 +894,14 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     # out_box = denormalize_bbox(n_out_box, height, width)
     # expected_bboxes.append(out_box)
 
-    size = (64, 76)
-    # xyxy format
+    format = datapoints.BoundingBoxFormat.XYXY
+    spatial_size = (64, 76)
     in_boxes = [
         [10.0, 15.0, 25.0, 35.0],
         [50.0, 5.0, 70.0, 22.0],
         [45.0, 46.0, 56.0, 62.0],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
-    )
+    in_boxes = torch.tensor(in_boxes, device=device)
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
@@ -924,15 +910,15 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
         format,
         top,
         left,
-        size[0],
-        size[1],
+        spatial_size[0],
+        spatial_size[1],
     )
 
     if format != datapoints.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
-    torch.testing.assert_close(output_spatial_size, size)
+    torch.testing.assert_close(output_spatial_size, spatial_size)
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
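
The crop reference boils down to translating box coordinates into the crop's frame (plus the new spatial size); sketched for XYXY input, ignoring clamping:

import torch

def crop_xyxy(boxes, top, left, height, width):
    # the crop origin becomes (left, top), so every coordinate shifts by that amount
    offset = torch.tensor([left, top, left, top], dtype=boxes.dtype)
    return boxes - offset, (height, width)

boxes, spatial_size = crop_xyxy(torch.tensor([[10.0, 15.0, 25.0, 35.0]]), top=5, left=8, height=64, width=76)
print(boxes.tolist(), spatial_size)  # [[2.0, 10.0, 17.0, 30.0]] (64, 76)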
@@ -980,8 +966,8 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
         bbox[3] = (bbox[3] - top_) * size_[0] / height_
         return bbox
 
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (100, 100)
-    # xyxy format
     in_boxes = [
         [10.0, 10.0, 20.0, 20.0],
         [5.0, 10.0, 15.0, 20.0],
@@ -1024,22 +1010,22 @@ def test_correctness_pad_bounding_box(device, padding):
     def _compute_expected_bbox(bbox, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)
 
-        bbox_format = bbox.format
-        bbox_dtype = bbox.dtype
+        dtype = bbox.dtype
+        format = bbox.format
         bbox = (
             bbox.clone()
-            if bbox_format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
+            if format == datapoints.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
-        if bbox.dtype != bbox_dtype:
+        bbox = convert_format_bounding_box(bbox, new_format=format)
+        if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
-            bbox = bbox.to(bbox_dtype)
+            bbox = bbox.to(dtype)
         return bbox
 
     def _compute_expected_spatial_size(bbox, padding_):
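
`_compute_expected_bbox` above is a shift in XYXY space: left/top padding moves the origin, so x coordinates gain `pad_left` and y coordinates gain `pad_up`. The core, isolated:

import torch

def pad_xyxy(boxes, pad_left, pad_up):
    boxes = boxes.clone()
    boxes[..., 0::2] += pad_left  # x1, x2
    boxes[..., 1::2] += pad_up    # y1, y2
    return boxes

print(pad_xyxy(torch.tensor([[10.0, 15.0, 25.0, 35.0]]), pad_left=4, pad_up=6))
# tensor([[14., 21., 29., 41.]])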
@@ -1108,9 +1094,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
            ]
        )
 
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -1122,22 +1106,22 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         numer = np.matmul(points, m1.T)
         denom = np.matmul(points, m2.T)
         transformed_points = numer / denom
-        out_bbox = [
-            np.min(transformed_points[:, 0]),
-            np.min(transformed_points[:, 1]),
-            np.max(transformed_points[:, 0]),
-            np.max(transformed_points[:, 1]),
-        ]
+        out_bbox = np.array(
+            [
+                np.min(transformed_points[:, 0]),
+                np.min(transformed_points[:, 1]),
+                np.max(transformed_points[:, 0]),
+                np.max(transformed_points[:, 1]),
+            ]
+        )
         out_bbox = datapoints.BoundingBox(
-            np.array(out_bbox),
+            out_bbox,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=bbox.spatial_size,
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
 
     spatial_size = (32, 38)
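
The `numer / denom` above is the projective division of the standard 8-coefficient perspective transform; as a sketch with explicit coefficients (a pure translation here, so the effect is easy to check):

import numpy as np

def perspective_points(points, coeffs):  # coeffs: (a, b, c, d, e, f, g, h)
    a, b, c, d, e, f, g, h = coeffs
    x, y = points[:, 0], points[:, 1]
    denom = g * x + h * y + 1.0  # projective division; 1 when g == h == 0
    return np.stack([(a * x + b * y + c) / denom, (d * x + e * y + f) / denom], axis=1)

pts = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])
print(perspective_points(pts, (1, 0, 2, 0, 1, 3, 0, 0)))  # pure translation by (2, 3)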
@@ -1146,14 +1130,12 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)
-        bboxes_format = bboxes.format
-        bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes = F.perspective_bounding_box(
-            bboxes,
-            bboxes_format,
-            None,
-            None,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes.format,
+            startpoints=None,
+            endpoints=None,
             coefficients=pcoeffs,
         )
@@ -1162,7 +1144,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+            bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
             expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)
...