Unverified commit a63046ce, authored by Philip Meier, committed by GitHub

only use plain tensors in kernel tests (#7230)

parent acabaf80
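
The diff below removes the datapoint-unwrapping helpers from the test infrastructure and instead has every kernel test pass plain tensors with explicit metadata. A minimal sketch of the resulting calling convention, assuming the prototype `datapoints` and functional API on this branch (the box values and the clamp call are illustrative, not taken from the diff):

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F

    bbox = datapoints.BoundingBox(
        [[10.0, 15.0, 25.0, 35.0]],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(64, 76),
    )

    # Kernels operate on plain tensors, so unwrap the datapoint subclass and
    # forward its metadata explicitly as keyword arguments.
    clamped = F.clamp_bounding_box(
        bbox.as_subclass(torch.Tensor),
        format=bbox.format,
        spatial_size=bbox.spatial_size,
    )
    assert type(clamped) is torch.Tensor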
@@ -237,13 +237,6 @@ class TensorLoader:
     def load(self, device):
         return self.fn(self.shape, self.dtype, device)
 
-    def unwrap(self):
-        return TensorLoader(
-            fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
-            shape=self.shape,
-            dtype=self.dtype,
-        )
-
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
...
@@ -64,11 +64,19 @@ class DispatcherInfo(InfoBase):
         if not filter_metadata:
             yield from sample_inputs
-        else:
-            for args_kwargs in sample_inputs:
-                for attribute in datapoint_type.__annotations__.keys():
-                    if attribute in args_kwargs.kwargs:
-                        del args_kwargs.kwargs[attribute]
+            return
+
+        import itertools
+
+        for args_kwargs in sample_inputs:
+            for name in itertools.chain(
+                datapoint_type.__annotations__.keys(),
+                # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                #  per-dispatcher level. However, so far there is no option for that.
+                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+            ):
+                if name in args_kwargs.kwargs:
+                    del args_kwargs.kwargs[name]
 
-                yield args_kwargs
+            yield args_kwargs
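
A standalone sketch of the filtering logic added above, assuming a datapoint class whose annotations name its metadata fields (as `datapoints.BoundingBox` does with `format` and `spatial_size`):

    import itertools

    def filter_datapoint_metadata(kwargs, datapoint_type):
        # Drop both the metadata names themselves and their "old_*" variants,
        # which conversion dispatchers such as convert_format_bounding_box accept.
        names = set(
            itertools.chain(
                datapoint_type.__annotations__.keys(),
                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
            )
        )
        return {key: value for key, value in kwargs.items() if key not in names}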
@@ -458,4 +466,18 @@ DISPATCHER_INFOS = [
             skip_dispatch_datapoint,
         ],
     ),
+    DispatcherInfo(
+        F.clamp_bounding_box,
+        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
+    DispatcherInfo(
+        F.convert_format_bounding_box,
+        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
 ]
@@ -12,7 +12,6 @@ import torchvision.prototype.transforms.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
-    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -337,7 +336,6 @@ def sample_inputs_resize_video():
 def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     old_height, old_width = spatial_size
     new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)
@@ -350,13 +348,15 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     )
 
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
+        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
     )
 
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
+    ):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
@@ -668,8 +668,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_convert_format_bounding_box():
     formats = list(datapoints.BoundingBoxFormat)
     for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
-        yield ArgsKwargs(bounding_box_loader, new_format=new_format)
-        yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
 
 
 def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
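
With the `.unwrap()` sample gone, the kernel is always exercised with `old_format` passed explicitly. A hedged example of that call on a plain tensor, assuming the hybrid kernel/dispatcher signature shown above; the import path and the XYWH box values are illustrative assumptions:

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.functional import convert_format_bounding_box

    # With a plain torch.Tensor input, old_format must be given explicitly.
    xywh = torch.tensor([[10.0, 15.0, 15.0, 20.0]])
    xyxy = convert_format_bounding_box(
        xywh,
        old_format=datapoints.BoundingBoxFormat.XYWH,
        new_format=datapoints.BoundingBoxFormat.XYXY,
    )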
@@ -680,14 +679,8 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_format_bounding_box():
-        if len(args_kwargs.args[0].shape) != 2:
-            continue
-
-        (loader, *other_args), kwargs = args_kwargs
-        if isinstance(loader, BoundingBoxLoader):
-            kwargs["old_format"] = loader.format
-            loader = loader.unwrap()
-        yield ArgsKwargs(loader, *other_args, **kwargs)
+        if len(args_kwargs.args[0].shape) == 2:
+            yield args_kwargs
 
 
 KERNEL_INFOS.append(
@@ -697,18 +690,6 @@ KERNEL_INFOS.append(
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("old_format") is None,
-            )
-        ],
     ),
 )
@@ -2049,10 +2030,8 @@ KERNEL_INFOS.extend(
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
-        yield ArgsKwargs(bounding_box_loader)
         yield ArgsKwargs(
-            bounding_box_loader.unwrap(),
+            bounding_box_loader,
             format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
         )
@@ -2063,19 +2042,6 @@ KERNEL_INFOS.append(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("format") is None
-                and arg_kwargs.kwargs.get("spatial_size") is None,
-            )
-        ],
     )
 )
...
@@ -121,8 +121,8 @@ class TestKernels:
     def test_logging(self, spy_on, info, args_kwargs, device):
         spy = spy_on(torch._C._log_api_usage_once)
 
-        args, kwargs = args_kwargs.load(device)
-        info.kernel(*args, **kwargs)
+        (input, *other_args), kwargs = args_kwargs.load(device)
+        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
 
         spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
@@ -134,6 +134,7 @@ class TestKernels:
         kernel_scripted = script(kernel_eager)
 
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         actual = kernel_scripted(input, *other_args, **kwargs)
         expected = kernel_eager(input, *other_args, **kwargs)
@@ -155,14 +156,12 @@ class TestKernels:
         if batched_tensor.ndim == data_dims:
             return batch
 
-        unbatcheds = []
-        for unbatched in (
-            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-        ):
-            if isinstance(batch, datapoints._datapoint.Datapoint):
-                unbatched = type(batch).wrap_like(batch, unbatched)
-            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
-        return unbatcheds
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
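
Since kernel outputs are now plain tensors, `_unbatch` no longer needs the datapoint re-wrapping branch. A self-contained sketch of the rewritten helper; the tensor/metadata split at the top is assumed from the surrounding method:

    import torch

    def unbatch(batch, *, data_dims):
        # The helper receives either a bare tensor or a (tensor, *metadata) tuple.
        if isinstance(batch, torch.Tensor):
            batched_tensor, metadata = batch, ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        # Recursively split off the leading batch dimension, re-attaching any metadata.
        return [
            unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    # A (2, 3, 4, 4) batch of 3-channel images unbatches into two (3, 4, 4) tensors.
    assert len(unbatch(torch.rand(2, 3, 4, 4), data_dims=3)) == 2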
@@ -195,6 +194,7 @@ class TestKernels:
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")
 
+        batched_input = batched_input.as_subclass(torch.Tensor)
         batched_output = info.kernel(batched_input, *other_args, **kwargs)
         actual = self._unbatch(batched_output, data_dims=data_dims)
@@ -212,6 +212,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_no_inplace(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         if input.numel() == 0:
             pytest.skip("The input has a degenerate shape.")
@@ -225,6 +226,7 @@ class TestKernels:
     @needs_cuda
     def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
         (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
+        input_cpu = input_cpu.as_subclass(torch.Tensor)
         input_cuda = input_cpu.to("cuda")
 
         output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
@@ -242,6 +244,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         output = info.kernel(input, *other_args, **kwargs)
         # Most kernels just return a tensor, but some also return some additional metadata
@@ -254,6 +257,7 @@ class TestKernels:
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         actual = info.kernel(input, *other_args, **kwargs)
         expected = info.reference_fn(input, *other_args, **kwargs)
@@ -271,6 +275,7 @@ class TestKernels:
     )
     def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         if input.dtype != torch.uint8:
             pytest.skip(f"Input dtype is {input.dtype}.")
@@ -341,7 +346,6 @@ class TestDispatchers:
     @pytest.mark.parametrize(
         "dispatcher",
         [
-            F.clamp_bounding_box,
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
@@ -647,21 +651,15 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_bounding_box_on_fixed_input(device):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
+    # xyxy format
     in_boxes = [
         [20, 25, 35, 45],
         [50, 5, 70, 22],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
         [1, 1, 5, 5],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 63
     scale = 0.89
@@ -686,11 +684,11 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     output_boxes = F.affine_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
-        (dx * spatial_size[1], dy * spatial_size[0]),
-        scale,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=(dx * spatial_size[1], dy * spatial_size[0]),
+        scale=scale,
         shear=(0, 0),
     )
@@ -725,9 +723,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         affine_matrix = affine_matrix[:2, :]
 
         height, width = bbox.spatial_size
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -766,10 +762,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return (
-            convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
-            (height, width),
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
 
     spatial_size = (32, 38)
@@ -778,8 +771,8 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes, output_spatial_size = F.rotate_bounding_box(
-            bboxes,
-            bboxes_format,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes_format,
             spatial_size=bboxes_spatial_size,
             angle=angle,
             expand=expand,
@@ -810,6 +803,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
 @pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
     # xyxy format
     in_boxes = [
@@ -818,13 +812,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
     # Tested parameters
     angle = 45
     center = None if expand else [12, 23]
@@ -854,9 +842,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
         expand=expand,
         center=center,
     )
@@ -906,16 +894,14 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     #     out_box = denormalize_bbox(n_out_box, height, width)
     #     expected_bboxes.append(out_box)
 
-    size = (64, 76)
-    # xyxy format
+    format = datapoints.BoundingBoxFormat.XYXY
+    spatial_size = (64, 76)
     in_boxes = [
         [10.0, 15.0, 25.0, 35.0],
         [50.0, 5.0, 70.0, 22.0],
         [45.0, 46.0, 56.0, 62.0],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
-    )
+    in_boxes = torch.tensor(in_boxes, device=device)
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
@@ -924,15 +910,15 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
         format,
         top,
         left,
-        size[0],
-        size[1],
+        spatial_size[0],
+        spatial_size[1],
     )
 
     if format != datapoints.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
-    torch.testing.assert_close(output_spatial_size, size)
+    torch.testing.assert_close(output_spatial_size, spatial_size)
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -980,8 +966,8 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
         bbox[3] = (bbox[3] - top_) * size_[0] / height_
         return bbox
 
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (100, 100)
-    # xyxy format
     in_boxes = [
         [10.0, 10.0, 20.0, 20.0],
         [5.0, 10.0, 15.0, 20.0],
@@ -1024,22 +1010,22 @@ def test_correctness_pad_bounding_box(device, padding):
     def _compute_expected_bbox(bbox, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)
 
-        bbox_format = bbox.format
-        bbox_dtype = bbox.dtype
+        dtype = bbox.dtype
+        format = bbox.format
         bbox = (
             bbox.clone()
-            if bbox_format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
+            if format == datapoints.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
-        if bbox.dtype != bbox_dtype:
+        bbox = convert_format_bounding_box(bbox, new_format=format)
+        if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
-            bbox = bbox.to(bbox_dtype)
+            bbox = bbox.to(dtype)
         return bbox
 
     def _compute_expected_spatial_size(bbox, padding_):
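
These reference helpers now lean on the datapoint-aware path of `convert_format_bounding_box`: when the input is a `datapoints.BoundingBox`, `old_format` can be omitted because it is read from the datapoint itself. A hedged sketch of that behavior; the import path and box values are illustrative assumptions:

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.functional import convert_format_bounding_box

    bbox = datapoints.BoundingBox(
        [10.0, 15.0, 15.0, 20.0],
        format=datapoints.BoundingBoxFormat.XYWH,
        spatial_size=(32, 38),
    )
    # No old_format here: the metadata comes from the datapoint input.
    xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)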
@@ -1108,9 +1094,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             ]
         )
 
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -1122,22 +1106,22 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         numer = np.matmul(points, m1.T)
         denom = np.matmul(points, m2.T)
         transformed_points = numer / denom
-        out_bbox = [
+        out_bbox = np.array(
+            [
                 np.min(transformed_points[:, 0]),
                 np.min(transformed_points[:, 1]),
                 np.max(transformed_points[:, 0]),
                 np.max(transformed_points[:, 1]),
-        ]
+            ]
+        )
         out_bbox = datapoints.BoundingBox(
-            np.array(out_bbox),
+            out_bbox,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=bbox.spatial_size,
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
 
     spatial_size = (32, 38)
@@ -1146,14 +1130,12 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)
-        bboxes_format = bboxes.format
-        bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes = F.perspective_bounding_box(
-            bboxes,
-            bboxes_format,
-            None,
-            None,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes.format,
+            startpoints=None,
+            endpoints=None,
             coefficients=pcoeffs,
         )
@@ -1162,7 +1144,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     expected_bboxes = []
     for bbox in bboxes:
-        bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+        bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
         expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
...