Unverified Commit 0b5ebae6 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

refactor prototype transforms functional tests (#5879)

parent 57ae04b4
This diff is collapsed.
import dataclasses
import functools
import itertools
import math
from typing import Any, Callable, Dict, Iterable, Optional
import numpy as np
import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
from datasets_utils import combinations_grid
from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
from torchvision.prototype import features
__all__ = ["KernelInfo", "KERNEL_INFOS"]
@dataclasses.dataclass
class KernelInfo:
kernel: Callable
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should
# not include extensive parameter combinations to keep to overall test count moderate.
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
# tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn: Optional[Callable] = None
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
def __post_init__(self):
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
atol=1e-5,
rtol=0,
agg_method="mean",
)
def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
def wrapper(image_tensor, *other_args, **kwargs):
if image_tensor.ndim > 3:
raise pytest.UsageError(
f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}"
)
# We don't need to convert back to tensor here, since `assert_close` does that automatically.
return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs)
return wrapper
KERNEL_INFOS = []
def sample_inputs_horizontal_flip_image_tensor():
for image_loader in make_image_loaders(dtypes=[torch.float32]):
yield ArgsKwargs(image_loader)
def reference_inputs_horizontal_flip_image_tensor():
for image_loader in make_image_loaders(extra_dims=[()]):
yield ArgsKwargs(image_loader)
def sample_inputs_horizontal_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size
)
def sample_inputs_horizontal_flip_mask():
for image_loader in make_mask_loaders(dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
KERNEL_INFOS.extend(
[
KernelInfo(
F.horizontal_flip_image_tensor,
sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
KernelInfo(
F.horizontal_flip_bounding_box,
sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
),
KernelInfo(
F.horizontal_flip_mask,
sample_inputs_fn=sample_inputs_horizontal_flip_mask,
),
]
)
def sample_inputs_resize_image_tensor():
for image_loader, interpolation in itertools.product(
make_image_loaders(dtypes=[torch.float32]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
):
height, width = image_loader.image_size
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield ArgsKwargs(image_loader, size=size, interpolation=interpolation)
def reference_inputs_resize_image_tensor():
for image_loader, interpolation in itertools.product(
make_image_loaders(extra_dims=[()]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
):
height, width = image_loader.image_size
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield ArgsKwargs(image_loader, size=size, interpolation=interpolation)
def sample_inputs_resize_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
height, width = bounding_box_loader.image_size
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.image_size)
KERNEL_INFOS.extend(
[
KernelInfo(
F.resize_image_tensor,
sample_inputs_fn=sample_inputs_resize_image_tensor,
reference_fn=pil_reference_wrapper(F.resize_image_pil),
reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
KernelInfo(
F.resize_bounding_box,
sample_inputs_fn=sample_inputs_resize_bounding_box,
),
]
)
_AFFINE_KWARGS = combinations_grid(
angle=[-87, 15, 90],
translate=[(5, 5), (-5, -5)],
scale=[0.77, 1.27],
shear=[(12, 12), (0, 0)],
)
def sample_inputs_affine_image_tensor():
for image_loader, interpolation_mode, center in itertools.product(
make_image_loaders(dtypes=[torch.float32]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
],
[None, (0, 0)],
):
for fill in [None, [0.5] * image_loader.num_channels]:
yield ArgsKwargs(
image_loader,
interpolation=interpolation_mode,
center=center,
fill=fill,
**_AFFINE_KWARGS[0],
)
def reference_inputs_affine_image_tensor():
for image, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
yield ArgsKwargs(
image,
interpolation=F.InterpolationMode.NEAREST,
**affine_kwargs,
)
def sample_inputs_affine_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size,
**_AFFINE_KWARGS[0],
)
def _compute_affine_matrix(angle, translate, scale, shear, center):
rot = math.radians(angle)
cx, cy = center
tx, ty = translate
sx, sy = [math.radians(sh_) for sh_ in shear]
c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
c_matrix_inv = np.linalg.inv(c_matrix)
rs_matrix = np.array(
[
[scale * math.cos(rot), -scale * math.sin(rot), 0],
[scale * math.sin(rot), scale * math.cos(rot), 0],
[0, 0, 1],
]
)
shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
return true_matrix
def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center):
if center is None:
center = [s * 0.5 for s in image_size[::-1]]
def transform(bbox):
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
affine_matrix = affine_matrix[:2, :]
bbox_xyxy = F.convert_format_bounding_box(bbox, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.T)
out_bbox = torch.tensor(
[
np.min(transformed_points[:, 0]),
np.min(transformed_points[:, 1]),
np.max(transformed_points[:, 0]),
np.max(transformed_points[:, 1]),
],
dtype=bbox.dtype,
)
return F.convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
)
if bounding_box.ndim < 2:
bounding_box = [bounding_box]
expected_bboxes = [transform(bbox) for bbox in bounding_box]
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
expected_bboxes = expected_bboxes[0]
return expected_bboxes
def reference_inputs_affine_bounding_box():
for bounding_box_loader, angle, translate, scale, shear, center in itertools.product(
make_bounding_box_loaders(extra_dims=[(4,)], image_size=(32, 38), dtypes=[torch.float32]),
range(-90, 90, 56),
range(-10, 10, 8),
[0.77, 1.0, 1.27],
range(-15, 15, 8),
[None, (12, 14)],
):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size,
angle=angle,
translate=(translate, translate),
scale=scale,
shear=(shear, shear),
center=center,
)
KERNEL_INFOS.extend(
[
KernelInfo(
F.affine_image_tensor,
sample_inputs_fn=sample_inputs_affine_image_tensor,
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
KernelInfo(
F.affine_bounding_box,
sample_inputs_fn=sample_inputs_affine_bounding_box,
reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box,
),
]
)
...@@ -1587,7 +1587,7 @@ class TestFixedSizeCrop: ...@@ -1587,7 +1587,7 @@ class TestFixedSizeCrop:
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
) )
masks = make_detection_mask(size=image_size, extra_dims=(batch_size,)) masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,)) labels = make_label(extra_dims=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1)) transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
......
...@@ -48,24 +48,6 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): ...@@ -48,24 +48,6 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
return sample_inputs_fn return sample_inputs_fn
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image_tensor():
for image in make_images():
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_mask():
for mask in make_masks():
yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def vertical_flip_image_tensor(): def vertical_flip_image_tensor():
for image in make_images(): for image in make_images():
...@@ -84,44 +66,6 @@ def vertical_flip_mask(): ...@@ -84,44 +66,6 @@ def vertical_flip_mask():
yield ArgsKwargs(mask) yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn
def resize_image_tensor():
for image, interpolation, max_size, antialias in itertools.product(
make_images(),
[F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST], # interpolation
[None, 34], # max_size
[False, True], # antialias
):
if antialias and interpolation == F.InterpolationMode.NEAREST:
continue
height, width = image.shape[-2:]
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
if max_size is not None:
size = [size[0]]
yield ArgsKwargs(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@register_kernel_info_from_sample_inputs_fn
def resize_bounding_box():
for bounding_box, max_size in itertools.product(
make_bounding_boxes(),
[None, 34], # max_size
):
height, width = bounding_box.image_size
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
if max_size is not None:
size = [size[0]]
yield ArgsKwargs(bounding_box, size=size, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def resize_mask(): def resize_mask():
for mask, max_size in itertools.product( for mask, max_size in itertools.product(
...@@ -138,45 +82,6 @@ def resize_mask(): ...@@ -138,45 +82,6 @@ def resize_mask():
yield ArgsKwargs(mask, size=size, max_size=max_size) yield ArgsKwargs(mask, size=size, max_size=max_size)
@register_kernel_info_from_sample_inputs_fn
def affine_image_tensor():
for image, angle, translate, scale, shear in itertools.product(
make_images(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
[0, 12], # shear
):
yield ArgsKwargs(
image,
angle=angle,
translate=(translate, translate),
scale=scale,
shear=(shear, shear),
interpolation=F.InterpolationMode.NEAREST,
)
@register_kernel_info_from_sample_inputs_fn
def affine_bounding_box():
for bounding_box, angle, translate, scale, shear in itertools.product(
make_bounding_boxes(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
[0, 12], # shear
):
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
image_size=bounding_box.image_size,
angle=angle,
translate=(translate, translate),
scale=scale,
shear=(shear, shear),
)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def affine_mask(): def affine_mask():
for mask, angle, translate, scale, shear in itertools.product( for mask, angle, translate, scale, shear in itertools.product(
...@@ -664,12 +569,7 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center) ...@@ -664,12 +569,7 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center)
image_size = (32, 38) image_size = (32, 38)
for bboxes in make_bounding_boxes( for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)):
image_sizes=[
image_size,
],
extra_dims=((4,),),
):
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_image_size = bboxes.image_size
...@@ -882,12 +782,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -882,12 +782,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
image_size = (32, 38) image_size = (32, 38)
for bboxes in make_bounding_boxes( for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)):
image_sizes=[
image_size,
],
extra_dims=((4,),),
):
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_image_size = bboxes.image_size
...@@ -1432,12 +1327,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -1432,12 +1327,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
pcoeffs = _get_perspective_coeffs(startpoints, endpoints) pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
for bboxes in make_bounding_boxes( for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)):
image_sizes=[
image_size,
],
extra_dims=((4,),),
):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_image_size = bboxes.image_size
...@@ -1466,7 +1356,8 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -1466,7 +1356,8 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"startpoints, endpoints", "startpoints, endpoints",
[ [
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]], # FIXME: this configuration leads to a difference in a single pixel
# [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]], [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]], [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
], ],
...@@ -1550,10 +1441,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1550,10 +1441,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
) )
return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False) return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
for bboxes in make_bounding_boxes( for bboxes in make_bounding_boxes(extra_dims=((4,),)):
image_sizes=[(32, 32), (24, 33), (32, 25)],
extra_dims=((4,),),
):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_image_size = bboxes.image_size
......
import pytest
import torch.testing
from common_utils import cpu_and_gpu, needs_cuda
from prototype_common_utils import assert_close
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F
def test_coverage():
tested = {info.kernel.__name__ for info in KERNEL_INFOS}
exposed = {
name
for name, kernel in F.__dict__.items()
if callable(kernel)
and any(
name.endswith(f"_{feature_name}")
for feature_name in {
"bounding_box",
"image_tensor",
"label",
"mask",
}
)
and name not in {"to_image_tensor"}
# TODO: The list below should be quickly reduced in the transition period. There is nothing that prevents us
# from adding `KernelInfo`'s for these kernels other than time.
and name
not in {
"adjust_brightness_image_tensor",
"adjust_contrast_image_tensor",
"adjust_gamma_image_tensor",
"adjust_hue_image_tensor",
"adjust_saturation_image_tensor",
"adjust_sharpness_image_tensor",
"affine_mask",
"autocontrast_image_tensor",
"center_crop_bounding_box",
"center_crop_image_tensor",
"center_crop_mask",
"clamp_bounding_box",
"convert_color_space_image_tensor",
"convert_format_bounding_box",
"crop_bounding_box",
"crop_image_tensor",
"crop_mask",
"elastic_bounding_box",
"elastic_image_tensor",
"elastic_mask",
"equalize_image_tensor",
"erase_image_tensor",
"five_crop_image_tensor",
"gaussian_blur_image_tensor",
"horizontal_flip_image_tensor",
"invert_image_tensor",
"normalize_image_tensor",
"pad_bounding_box",
"pad_image_tensor",
"pad_mask",
"perspective_bounding_box",
"perspective_image_tensor",
"perspective_mask",
"posterize_image_tensor",
"resize_mask",
"resized_crop_bounding_box",
"resized_crop_image_tensor",
"resized_crop_mask",
"rotate_bounding_box",
"rotate_image_tensor",
"rotate_mask",
"solarize_image_tensor",
"ten_crop_image_tensor",
"vertical_flip_bounding_box",
"vertical_flip_image_tensor",
"vertical_flip_mask",
}
}
untested = exposed - tested
if untested:
raise AssertionError(
f"The kernel(s) {sequence_to_str(sorted(untested), separate_last='and ')} "
f"are exposed through `torchvision.prototype.transforms.functional`, but are not tested. "
f"Please add a `KernelInfo` to the `KERNEL_INFOS` list in `test/prototype_transforms_kernel_infos.py`."
)
class TestCommon:
sample_inputs = pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel.__name__}")
for info in KERNEL_INFOS
for args_kwargs in info.sample_inputs_fn()
],
)
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device):
kernel_eager = info.kernel
try:
kernel_scripted = torch.jit.script(kernel_eager)
except Exception as error:
raise AssertionError("Trying to `torch.jit.script` the kernel raised the error above.") from error
args, kwargs = args_kwargs.load(device)
actual = kernel_scripted(*args, **kwargs)
expected = kernel_eager(*args, **kwargs)
assert_close(actual, expected, **info.closeness_kwargs)
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_batched_vs_single(self, info, args_kwargs, device):
def unbind_batch_dims(batched_tensor, *, data_dims):
if batched_tensor.ndim == data_dims:
return batched_tensor
return [unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
def stack_batch_dims(unbound_tensor):
if isinstance(unbound_tensor[0], torch.Tensor):
return torch.stack(unbound_tensor)
return torch.stack([stack_batch_dims(t) for t in unbound_tensor])
(batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
features.Image: 3,
features.BoundingBox: 1,
# `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
# it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
features.Mask: 2,
}.get(feature_type)
if data_dims is None:
raise pytest.UsageError(
f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
) from None
elif batched_input.ndim <= data_dims:
pytest.skip("Input is not batched.")
elif not all(batched_input.shape[:-data_dims]):
pytest.skip("Input has a degenerate batch shape.")
actual = info.kernel(batched_input, *other_args, **kwargs)
single_inputs = unbind_batch_dims(batched_input, data_dims=data_dims)
single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
expected = stack_batch_dims(single_outputs)
assert_close(actual, expected, **info.closeness_kwargs)
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_no_inplace(self, info, args_kwargs, device):
(input, *other_args), kwargs = args_kwargs.load(device)
if input.numel() == 0:
pytest.skip("The input has a degenerate shape.")
input_version = input._version
output = info.kernel(input, *other_args, **kwargs)
assert output is not input or output._version == input_version
@sample_inputs
@needs_cuda
def test_cuda_vs_cpu(self, info, args_kwargs):
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
input_cuda = input_cpu.to("cuda")
output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
output_cuda = info.kernel(input_cuda, *other_args, **kwargs)
assert_close(output_cuda, output_cpu, check_device=False)
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel.__name__}")
for info in KERNEL_INFOS
for args_kwargs in info.reference_inputs_fn()
if info.reference_fn is not None
],
)
def test_against_reference(self, info, args_kwargs):
args, kwargs = args_kwargs.load("cpu")
actual = info.kernel(*args, **kwargs)
expected = info.reference_fn(*args, **kwargs)
assert_close(actual, expected, **info.closeness_kwargs, check_dtype=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment