Unverified commit a63046ce, authored by Philip Meier, committed by GitHub

only use plain tensors in kernel tests (#7230)

parent acabaf80
@@ -237,13 +237,6 @@ class TensorLoader:
     def load(self, device):
         return self.fn(self.shape, self.dtype, device)
 
-    def unwrap(self):
-        return TensorLoader(
-            fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
-            shape=self.shape,
-            dtype=self.dtype,
-        )
-
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
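For context, `Tensor.as_subclass` is what both the removed `unwrap` helper and the inlined calls further down use to view a datapoint as a plain tensor. A minimal sketch of what it does (the `BoundingBox` class here is an illustrative stand-in, not `datapoints.BoundingBox`):

    import torch

    class BoundingBox(torch.Tensor):
        # stand-in for a datapoint subclass that carries extra metadata
        pass

    box = torch.rand(4).as_subclass(BoundingBox)
    assert type(box) is BoundingBox

    plain = box.as_subclass(torch.Tensor)  # reinterpret as a plain tensor
    assert type(plain) is torch.Tensor
    assert plain.data_ptr() == box.data_ptr()  # shares storage, no copy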
@@ -64,11 +64,19 @@ class DispatcherInfo(InfoBase):
         if not filter_metadata:
             yield from sample_inputs
-        else:
-            for args_kwargs in sample_inputs:
-                for attribute in datapoint_type.__annotations__.keys():
-                    if attribute in args_kwargs.kwargs:
-                        del args_kwargs.kwargs[attribute]
+            return
+
+        import itertools
+
+        for args_kwargs in sample_inputs:
+            for name in itertools.chain(
+                datapoint_type.__annotations__.keys(),
+                # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                # per-dispatcher level. However, so far there is no option for that.
+                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+            ):
+                if name in args_kwargs.kwargs:
+                    del args_kwargs.kwargs[name]
 
-                yield args_kwargs
+            yield args_kwargs
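To illustrate the filtering: for `BoundingBox`, whose annotated metadata attributes are `format` and `spatial_size`, the loop now also strips the `old_*` variants that conversion dispatchers accept. A self-contained sketch with made-up values:

    import itertools

    annotations = ("format", "spatial_size")  # stands in for BoundingBox.__annotations__.keys()
    kwargs = {"old_format": "XYXY", "new_format": "CXCYWH", "spatial_size": (32, 32)}

    for name in itertools.chain(annotations, (f"old_{name}" for name in annotations)):
        kwargs.pop(name, None)  # same effect as the guarded del above

    assert kwargs == {"new_format": "CXCYWH"}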
@@ -458,4 +466,18 @@ DISPATCHER_INFOS = [
             skip_dispatch_datapoint,
         ],
     ),
+    DispatcherInfo(
+        F.clamp_bounding_box,
+        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
+    DispatcherInfo(
+        F.convert_format_bounding_box,
+        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
+        test_marks=[
+            skip_dispatch_datapoint,
+        ],
+    ),
 ]
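Both functions registered above are hybrid kernel/dispatchers: called with a plain tensor they require explicit metadata, called with a `datapoints.BoundingBox` they read it off the input. A rough sketch of that pattern, simplified and hypothetical rather than torchvision's actual implementation:

    import torch

    class BoundingBox(torch.Tensor):
        # illustrative stand-in: a tensor subclass carrying metadata
        format = "XYXY"
        spatial_size = (32, 32)

    def clamp_bounding_box(inpt, format=None, spatial_size=None):
        if format is None or spatial_size is None:
            if not isinstance(inpt, BoundingBox):
                raise TypeError("plain tensors need explicit `format` and `spatial_size`")
            format, spatial_size = inpt.format, inpt.spatial_size
        height, width = spatial_size
        xyxy = inpt.as_subclass(torch.Tensor)
        return xyxy.clamp(
            min=xyxy.new_tensor([0, 0, 0, 0]),
            max=xyxy.new_tensor([width, height, width, height]),
        )

    # the kernel-style call the tests now exercise: plain tensor, explicit metadata
    out = clamp_bounding_box(
        torch.tensor([[-5.0, 2.0, 40.0, 10.0]]), format="XYXY", spatial_size=(32, 32)
    )
    assert out.tolist() == [[0.0, 2.0, 32.0, 10.0]]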
@@ -12,7 +12,6 @@ import torchvision.prototype.transforms.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
-    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -337,7 +336,6 @@ def sample_inputs_resize_video():
 def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     old_height, old_width = spatial_size
     new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)
@@ -350,13 +348,15 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
     )
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
+        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
     )
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
+    ):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
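Resizing an XYXY box is a pure scaling transform, which is why the reference can route it through the affine helper. The core arithmetic in isolation (a generic sketch with made-up sizes):

    import torch

    old_height, old_width = 64, 76
    new_height, new_width = 32, 38

    # scale x-coordinates by new_width/old_width and y-coordinates by new_height/old_height
    scale = torch.tensor([new_width / old_width, new_height / old_height] * 2)
    box = torch.tensor([10.0, 20.0, 30.0, 60.0])  # XYXY
    assert (box * scale).tolist() == [5.0, 10.0, 15.0, 30.0]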
@@ -668,8 +668,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_convert_format_bounding_box():
     formats = list(datapoints.BoundingBoxFormat)
     for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
-        yield ArgsKwargs(bounding_box_loader, new_format=new_format)
-        yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
 
 
 def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
@@ -680,14 +679,8 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_format_bounding_box():
-        if len(args_kwargs.args[0].shape) != 2:
-            continue
-
-        (loader, *other_args), kwargs = args_kwargs
-        if isinstance(loader, BoundingBoxLoader):
-            kwargs["old_format"] = loader.format
-            loader = loader.unwrap()
-        yield ArgsKwargs(loader, *other_args, **kwargs)
+        if len(args_kwargs.args[0].shape) == 2:
+            yield args_kwargs
 
 
 KERNEL_INFOS.append(
@@ -697,18 +690,6 @@ KERNEL_INFOS.append(
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("old_format") is None,
-            )
-        ],
     ),
 )
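The xfail removed here existed because JIT scripting strips the `BoundingBox` subclass, after which the kernel has no metadata to fall back on. Since every sample input now passes a plain tensor with an explicit `old_format`, that failure mode is gone by construction. A toy reproduction of the old problem (hypothetical hybrid, not torchvision's code):

    import torch

    def convert_format_bounding_box(box, old_format=None, new_format=None):
        if old_format is None:
            # the plain-tensor (kernel) path cannot infer the input format
            raise ValueError("plain tensors need an explicit `old_format`")
        return box  # actual conversion elided in this sketch

    box = torch.rand(5, 4)
    convert_format_bounding_box(box, old_format="XYXY", new_format="XYWH")  # fine
    try:
        convert_format_bounding_box(box, new_format="XYWH")
    except ValueError as e:
        print(e)  # this is what the removed xfail papered over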
@@ -2049,10 +2030,8 @@ KERNEL_INFOS.extend(
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
-        yield ArgsKwargs(bounding_box_loader)
         yield ArgsKwargs(
-            bounding_box_loader.unwrap(),
+            bounding_box_loader,
             format=bounding_box_loader.format,
             spatial_size=bounding_box_loader.spatial_size,
         )
@@ -2063,19 +2042,6 @@ KERNEL_INFOS.append(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
-        test_marks=[
-            mark_framework_limitation(
-                ("TestKernels", "test_scripted_vs_eager"),
-                reason=(
-                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
-                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
-                    "`spatial_size` was passed"
-                ),
-                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
-                and arg_kwargs.kwargs.get("format") is None
-                and arg_kwargs.kwargs.get("spatial_size") is None,
-            )
-        ],
     )
 )
@@ -121,8 +121,8 @@ class TestKernels:
     def test_logging(self, spy_on, info, args_kwargs, device):
         spy = spy_on(torch._C._log_api_usage_once)
-        args, kwargs = args_kwargs.load(device)
-        info.kernel(*args, **kwargs)
+        (input, *other_args), kwargs = args_kwargs.load(device)
+        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
 
         spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
@@ -134,6 +134,7 @@ class TestKernels:
         kernel_scripted = script(kernel_eager)
 
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         actual = kernel_scripted(input, *other_args, **kwargs)
         expected = kernel_eager(input, *other_args, **kwargs)
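The explicit downcast matters for scripting: a scripted function treats any tensor subclass as a plain `torch.Tensor`, so without the cast the eager and scripted calls would not see equivalent inputs. A small demonstration (illustrative subclass; behavior as of the PyTorch versions this test suite targets):

    import torch
    from torch.jit import script

    class BoundingBox(torch.Tensor):  # illustrative stand-in
        pass

    @script
    def identity(t: torch.Tensor) -> torch.Tensor:
        return t

    box = torch.rand(4).as_subclass(BoundingBox)
    assert type(identity(box)) is torch.Tensor  # the subclass does not survive scripting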
@@ -155,14 +156,12 @@ class TestKernels:
         if batched_tensor.ndim == data_dims:
             return batch
 
-        unbatcheds = []
-        for unbatched in (
-            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-        ):
-            if isinstance(batch, datapoints._datapoint.Datapoint):
-                unbatched = type(batch).wrap_like(batch, unbatched)
-            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
-        return unbatcheds
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
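`_unbatch` now just peels leading batch dimensions recursively with `Tensor.unbind`; the datapoint re-wrapping branch is gone because everything is a plain tensor. The recursion in isolation:

    import torch

    def unbatch(batch, *, data_dims):
        if batch.ndim == data_dims:
            return batch
        return [unbatch(t, data_dims=data_dims) for t in batch.unbind(0)]

    batched = torch.arange(12).reshape(2, 3, 2)  # two batch dims over 1-D data
    nested = unbatch(batched, data_dims=1)
    assert len(nested) == 2 and len(nested[0]) == 3
    assert nested[1][2].tolist() == [10, 11]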
@@ -195,6 +194,7 @@ class TestKernels:
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")
 
+        batched_input = batched_input.as_subclass(torch.Tensor)
         batched_output = info.kernel(batched_input, *other_args, **kwargs)
         actual = self._unbatch(batched_output, data_dims=data_dims)
@@ -212,6 +212,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_no_inplace(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         if input.numel() == 0:
             pytest.skip("The input has a degenerate shape.")
@@ -225,6 +226,7 @@ class TestKernels:
     @needs_cuda
     def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
         (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
+        input_cpu = input_cpu.as_subclass(torch.Tensor)
         input_cuda = input_cpu.to("cuda")
 
         output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
@@ -242,6 +244,7 @@ class TestKernels:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
+        input = input.as_subclass(torch.Tensor)
 
         output = info.kernel(input, *other_args, **kwargs)
         # Most kernels just return a tensor, but some also return some additional metadata
@@ -254,6 +257,7 @@ class TestKernels:
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         actual = info.kernel(input, *other_args, **kwargs)
         expected = info.reference_fn(input, *other_args, **kwargs)
@@ -271,6 +275,7 @@ class TestKernels:
     )
     def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
+        input = input.as_subclass(torch.Tensor)
 
         if input.dtype != torch.uint8:
             pytest.skip(f"Input dtype is {input.dtype}.")
@@ -341,7 +346,6 @@ class TestDispatchers:
     @pytest.mark.parametrize(
         "dispatcher",
         [
-            F.clamp_bounding_box,
             F.get_dimensions,
             F.get_image_num_channels,
             F.get_image_size,
@@ -647,21 +651,15 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_bounding_box_on_fixed_input(device):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
-    # xyxy format
     in_boxes = [
         [20, 25, 35, 45],
         [50, 5, 70, 22],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
         [1, 1, 5, 5],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
 
     # Tested parameters
     angle = 63
     scale = 0.89
@@ -686,11 +684,11 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     output_boxes = F.affine_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
-        (dx * spatial_size[1], dy * spatial_size[0]),
-        scale,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=(dx * spatial_size[1], dy * spatial_size[0]),
+        scale=scale,
         shear=(0, 0),
     )
@@ -725,9 +723,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         affine_matrix = affine_matrix[:2, :]
 
         height, width = bbox.spatial_size
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -766,10 +762,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return (
-            convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
-            (height, width),
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
 
     spatial_size = (32, 38)
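Dropping `old_format` works because `convert_format_bounding_box` can read the source format off a `BoundingBox` input. The conversion math itself is small; for instance XYXY to CXCYWH looks roughly like this (a generic sketch, not torchvision's implementation):

    import torch

    def xyxy_to_cxcywh(box: torch.Tensor) -> torch.Tensor:
        x1, y1, x2, y2 = box.unbind(-1)
        return torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=-1)

    box = torch.tensor([[10.0, 20.0, 30.0, 60.0]])
    assert xyxy_to_cxcywh(box).tolist() == [[20.0, 40.0, 20.0, 40.0]]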
@@ -778,8 +771,8 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes, output_spatial_size = F.rotate_bounding_box(
-            bboxes,
-            bboxes_format,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes_format,
             spatial_size=bboxes_spatial_size,
             angle=angle,
             expand=expand,
@@ -810,6 +803,7 @@
 @pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     # Check transformation against known expected output
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (64, 64)
     # xyxy format
     in_boxes = [
@@ -818,13 +812,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
         [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes,
-        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
-        dtype=torch.float64,
-        device=device,
-    )
+    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
 
     # Tested parameters
     angle = 45
     center = None if expand else [12, 23]
@@ -854,9 +842,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
-        in_boxes.format,
-        in_boxes.spatial_size,
-        angle,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
         expand=expand,
         center=center,
     )
@@ -906,16 +894,14 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     #     out_box = denormalize_bbox(n_out_box, height, width)
     #     expected_bboxes.append(out_box)
 
-    size = (64, 76)
-    # xyxy format
+    format = datapoints.BoundingBoxFormat.XYXY
+    spatial_size = (64, 76)
    in_boxes = [
         [10.0, 15.0, 25.0, 35.0],
         [50.0, 5.0, 70.0, 22.0],
         [45.0, 46.0, 56.0, 62.0],
     ]
-    in_boxes = datapoints.BoundingBox(
-        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
-    )
+    in_boxes = torch.tensor(in_boxes, device=device)
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
@@ -924,15 +910,15 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
         format,
         top,
         left,
-        size[0],
-        size[1],
+        spatial_size[0],
+        spatial_size[1],
     )
 
     if format != datapoints.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
-    torch.testing.assert_close(output_spatial_size, size)
+    torch.testing.assert_close(output_spatial_size, spatial_size)
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -980,8 +966,8 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
         bbox[3] = (bbox[3] - top_) * size_[0] / height_
         return bbox
 
+    format = datapoints.BoundingBoxFormat.XYXY
     spatial_size = (100, 100)
-    # xyxy format
     in_boxes = [
         [10.0, 10.0, 20.0, 20.0],
         [5.0, 10.0, 15.0, 20.0],
@@ -1024,22 +1010,22 @@ def test_correctness_pad_bounding_box(device, padding):
     def _compute_expected_bbox(bbox, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)
 
-        bbox_format = bbox.format
-        bbox_dtype = bbox.dtype
+        dtype = bbox.dtype
+        format = bbox.format
         bbox = (
             bbox.clone()
-            if bbox_format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
+            if format == datapoints.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         )
 
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
 
-        bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
-        if bbox.dtype != bbox_dtype:
+        bbox = convert_format_bounding_box(bbox, new_format=format)
+        if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
-            bbox = bbox.to(bbox_dtype)
+            bbox = bbox.to(dtype)
         return bbox
 
     def _compute_expected_spatial_size(bbox, padding_):
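The expected result for padding in XYXY is just an offset of every coordinate by the left/top pad. In isolation, with made-up values:

    import torch

    bbox = torch.tensor([10.0, 15.0, 25.0, 35.0])  # XYXY
    pad_left, pad_up = 4, 7

    padded = bbox.clone()
    padded[0::2] += pad_left  # x1 and x2
    padded[1::2] += pad_up    # y1 and y2
    assert padded.tolist() == [14.0, 22.0, 29.0, 42.0]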
@@ -1108,9 +1094,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             ]
         )
 
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -1122,22 +1106,22 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         numer = np.matmul(points, m1.T)
         denom = np.matmul(points, m2.T)
         transformed_points = numer / denom
-        out_bbox = [
-            np.min(transformed_points[:, 0]),
-            np.min(transformed_points[:, 1]),
-            np.max(transformed_points[:, 0]),
-            np.max(transformed_points[:, 1]),
-        ]
+        out_bbox = np.array(
+            [
+                np.min(transformed_points[:, 0]),
+                np.min(transformed_points[:, 1]),
+                np.max(transformed_points[:, 0]),
+                np.max(transformed_points[:, 1]),
+            ]
+        )
         out_bbox = datapoints.BoundingBox(
-            np.array(out_bbox),
+            out_bbox,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=bbox.spatial_size,
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
-        )
+        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
 
     spatial_size = (32, 38)
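The helper applies the perspective coefficients to each homogeneous corner and divides by the projective denominator. The same computation standalone, with arbitrary example coefficients:

    import numpy as np

    # perspective coefficients (a, b, c, d, e, f, g, h), chosen arbitrarily
    a, b, c, d, e, f, g, h = 1.2, 0.1, 3.0, -0.05, 0.9, 2.0, 0.001, 0.002

    m1 = np.array([[a, b, c], [d, e, f]])
    m2 = np.array([[g, h, 1.0], [g, h, 1.0]])

    points = np.array([[10.0, 15.0, 1.0], [25.0, 35.0, 1.0]])  # corners as (x, y, 1)
    transformed = np.matmul(points, m1.T) / np.matmul(points, m2.T)
    print(transformed)  # transformed (x, y) per corner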
@@ -1146,14 +1130,12 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)
-        bboxes_format = bboxes.format
-        bboxes_spatial_size = bboxes.spatial_size
 
         output_bboxes = F.perspective_bounding_box(
-            bboxes,
-            bboxes_format,
-            None,
-            None,
+            bboxes.as_subclass(torch.Tensor),
+            format=bboxes.format,
+            startpoints=None,
+            endpoints=None,
             coefficients=pcoeffs,
         )
@@ -1162,7 +1144,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+            bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)