"vscode:/vscode.git/clone" did not exist on "75c8a97a3c4bc889875ea7f9a56dcf89f817ac30"
Unverified Commit 65769ab7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix prototype transforms tests with set agg_method (#6934)

* fix prototype transforms tests with set agg_method

* use individual tolerances

* refactor PIL reference test

* increase tolerance for elastic_mask

* fix autocontrast tolerances

* increase tolerance for RandomAutocontrast
parent d72e9064
......@@ -12,17 +12,9 @@ import torch
import torch.testing
from datasets_utils import combinations_grid
from torch.nn.functional import one_hot
from torch.testing._comparison import (
assert_equal as _assert_equal,
BooleanPair,
ErrorMeta,
NonePair,
NumberPair,
TensorLikePair,
UnsupportedInputs,
)
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.prototype.transforms.functional import to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value
__all__ = [
......@@ -54,7 +46,7 @@ __all__ = [
]
class PILImagePair(TensorLikePair):
class ImagePair(TensorLikePair):
def __init__(
self,
actual,
......@@ -64,44 +56,13 @@ class PILImagePair(TensorLikePair):
allowed_percentage_diff=None,
**other_parameters,
):
if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
raise UnsupportedInputs()
# This parameter is ignored to enable checking PIL images to tensor images no on the CPU
other_parameters["check_device"] = False
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = [to_image_tensor(input) for input in [actual, expected]]
super().__init__(actual, expected, **other_parameters)
self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
self.allowed_percentage_diff = allowed_percentage_diff
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
actual, expected = [
to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
for input in [actual, expected]
]
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
# image to a tensor adds a singleton leading dimension.
# Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
# `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
# shape check that will fail if we don't broadcast before.
try:
actual, expected = torch.broadcast_tensors(actual, expected)
except RuntimeError:
raise ErrorMeta(
AssertionError,
f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
id=id,
) from None
return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
def _equalize_attributes(self, actual, expected):
if actual.dtype != expected.dtype:
dtype = torch.promote_types(actual.dtype, expected.dtype)
actual = convert_dtype_image_tensor(actual, dtype)
expected = convert_dtype_image_tensor(expected, dtype)
return super()._equalize_attributes(actual, expected)
def compare(self) -> None:
actual, expected = self.actual, self.expected
......@@ -111,16 +72,24 @@ class PILImagePair(TensorLikePair):
abs_diff = torch.abs(actual - expected)
if self.allowed_percentage_diff is not None:
percentage_diff = (abs_diff != 0).to(torch.float).mean()
percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
if percentage_diff > self.allowed_percentage_diff:
self._make_error_meta(AssertionError, "percentage mismatch")
raise self._make_error_meta(
AssertionError,
f"{percentage_diff:.1%} elements differ, "
f"but only {self.allowed_percentage_diff:.1%} is allowed",
)
if self.agg_method is None:
super()._compare_values(actual, expected)
else:
err = self.agg_method(abs_diff.to(torch.float64))
if err > self.atol:
self._make_error_meta(AssertionError, "aggregated mismatch")
agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
if agg_abs_diff > self.atol:
raise self._make_error_meta(
AssertionError,
f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
f"but only {self.atol} is allowed.",
)
def assert_close(
......@@ -148,7 +117,7 @@ def assert_close(
NonePair,
BooleanPair,
NumberPair,
PILImagePair,
ImagePair,
TensorLikePair,
),
allow_subclasses=allow_subclasses,
......@@ -167,6 +136,32 @@ def assert_close(
assert_equal = functools.partial(assert_close, rtol=0, atol=0)
def parametrized_error_message(*args, **kwargs):
def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
else:
return repr(obj)
if args or kwargs:
postfix = "\n".join(
[
"",
"Failure happened for the following parameters:",
"",
*[to_str(arg) for arg in args],
*[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
]
)
else:
postfix = ""
def wrapper(msg):
return msg + postfix
return wrapper
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
......@@ -656,6 +651,13 @@ class InfoBase:
]
def get_closeness_kwargs(self, test_id, *, dtype, device):
if not (isinstance(test_id, tuple) and len(test_id) == 2):
msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
if callable(test_id):
msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
else:
msg += f", but got {test_id} instead."
raise pytest.UsageError(msg)
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())
......@@ -4,6 +4,7 @@ import itertools
import math
import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
......@@ -49,6 +50,12 @@ class KernelInfo(InfoBase):
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn=None,
# If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
# the reference inputs. This is usually used whenever we use a PIL kernel as reference.
# Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
float32_vs_uint8=False,
# See InfoBase
test_marks=None,
# See InfoBase
......@@ -60,28 +67,64 @@ class KernelInfo(InfoBase):
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn
if float32_vs_uint8 and not callable(float32_vs_uint8):
float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731
self.float32_vs_uint8 = float32_vs_uint8
DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
(("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
}
CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE = {
(("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): dict(atol=atol, rtol=0)
for dtype, atol in [(torch.uint8, 1), (torch.float32, 1 / 255)]
}
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, agg_method=None):
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, agg_method=agg_method)
def cuda_vs_cpu_pixel_difference(atol=1):
return {
(("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): _pixel_difference_closeness_kwargs(atol, dtype=dtype)
for dtype in [torch.uint8, torch.float32]
}
def pil_reference_pixel_difference(atol=1, agg_method=None):
return {
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): _pixel_difference_closeness_kwargs(
atol, agg_method=agg_method
)
}
def float32_vs_uint8_pixel_difference(atol=1, agg_method=None):
return {
(
("TestKernels", "test_float32_vs_uint8"),
torch.float32,
"cpu",
): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, agg_method=agg_method)
}
def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
def wrapper(image_tensor, *other_args, **kwargs):
if image_tensor.ndim > 3:
def wrapper(input_tensor, *other_args, **kwargs):
if input_tensor.dtype != torch.uint8:
raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
if input_tensor.ndim > 3:
raise pytest.UsageError(
f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}"
f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
)
# We don't need to convert back to tensor here, since `assert_close` does that automatically.
return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs)
input_pil = F.to_image_pil(input_tensor)
output_pil = pil_kernel(input_pil, *other_args, **kwargs)
if not isinstance(output_pil, PIL.Image.Image):
return output_pil
output_tensor = F.to_image_tensor(output_pil)
# 2D mask shenanigans
if output_tensor.ndim == 2 and input_tensor.ndim == 3:
output_tensor = output_tensor.unsqueeze(0)
elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
output_tensor = output_tensor.squeeze(0)
return output_tensor
return wrapper
......@@ -126,7 +169,7 @@ def sample_inputs_horizontal_flip_image_tensor():
def reference_inputs_horizontal_flip_image_tensor():
for image_loader in make_image_loaders(extra_dims=[()]):
for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
......@@ -180,7 +223,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
),
KernelInfo(
F.horizontal_flip_bounding_box,
......@@ -244,7 +287,7 @@ def reference_resize_image_tensor(*args, **kwargs):
def reference_inputs_resize_image_tensor():
for image_loader, interpolation in itertools.product(
make_image_loaders(extra_dims=[()]),
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.NEAREST_EXACT,
......@@ -324,9 +367,13 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_image_tensor,
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
**CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
# TODO: investigate
**pil_reference_pixel_difference(110, agg_method="mean"),
**cuda_vs_cpu_pixel_difference(),
# TODO: investigate
**float32_vs_uint8_pixel_difference(50),
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
......@@ -346,7 +393,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_mask,
reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(10),
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
......@@ -354,7 +402,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.resize_video,
sample_inputs_fn=sample_inputs_resize_video,
closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
......@@ -400,6 +448,36 @@ _DIVERSE_AFFINE_PARAMS = [
]
def get_fills(*, num_channels, dtype, vector=True):
yield None
max_value = get_max_value(dtype)
# This intentionally gives us a float and an int scalar fill value
yield max_value / 2
yield max_value
if not vector:
return
if dtype.is_floating_point:
yield [0.1 + c / 10 for c in range(num_channels)]
else:
yield [12.0 + c for c in range(num_channels)]
def float32_vs_uint8_fill_adapter(other_args, kwargs):
fill = kwargs.get("fill")
if fill is None:
return other_args, kwargs
if isinstance(fill, (int, float)):
fill /= 255
else:
fill = type(fill)(fill_ / 255 for fill_ in fill)
return other_args, dict(kwargs, fill=fill)
def sample_inputs_affine_image_tensor():
make_affine_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
......@@ -409,10 +487,7 @@ def sample_inputs_affine_image_tensor():
yield ArgsKwargs(image_loader, **affine_params)
for image_loader in make_affine_image_loaders():
fills = [None, 0.5]
if image_loader.num_channels > 1:
fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
for fill in fills:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, **_full_affine_params(), fill=fill)
for image_loader, interpolation in itertools.product(
......@@ -426,7 +501,9 @@ def sample_inputs_affine_image_tensor():
def reference_inputs_affine_image_tensor():
for image_loader, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
for image_loader, affine_kwargs in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _AFFINE_KWARGS
):
yield ArgsKwargs(
image_loader,
interpolation=F.InterpolationMode.NEAREST,
......@@ -564,7 +641,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_image_tensor,
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(10, agg_method="mean"),
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
......@@ -589,7 +667,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_mask,
reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
closeness_kwargs=pil_reference_pixel_difference(10),
float32_vs_uint8=True,
test_marks=[
xfail_jit_python_scalar_arg("shear"),
],
......@@ -631,7 +710,9 @@ KERNEL_INFOS.append(
def sample_inputs_convert_color_space_image_tensor():
color_spaces = list(set(features.ColorSpace) - {features.ColorSpace.OTHER})
color_spaces = sorted(
set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value
)
for old_color_space, new_color_space in cycle_over(color_spaces):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
......@@ -659,7 +740,7 @@ def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_c
def reference_inputs_convert_color_space_image_tensor():
for args_kwargs in sample_inputs_convert_color_space_image_tensor():
(image_loader, *other_args), kwargs = args_kwargs
if len(image_loader.shape) == 3:
if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
yield args_kwargs
......@@ -678,7 +759,10 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.convert_color_space_video,
......@@ -694,7 +778,7 @@ def sample_inputs_vertical_flip_image_tensor():
def reference_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(extra_dims=[()]):
for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
......@@ -739,7 +823,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
),
KernelInfo(
F.vertical_flip_bounding_box,
......@@ -775,10 +859,7 @@ def sample_inputs_rotate_image_tensor():
yield ArgsKwargs(image_loader, angle=15.0, center=center)
for image_loader in make_rotate_image_loaders():
fills = [None, 0.5]
if image_loader.num_channels > 1:
fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
for fill in fills:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, angle=15.0, fill=fill)
for image_loader, interpolation in itertools.product(
......@@ -789,7 +870,9 @@ def sample_inputs_rotate_image_tensor():
def reference_inputs_rotate_image_tensor():
for image_loader, angle in itertools.product(make_image_loaders(extra_dims=[()]), _ROTATE_ANGLES):
for image_loader, angle in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _ROTATE_ANGLES
):
yield ArgsKwargs(image_loader, angle=angle)
......@@ -830,7 +913,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_image_tensor,
reference_fn=pil_reference_wrapper(F.rotate_image_pil),
reference_inputs_fn=reference_inputs_rotate_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
# TODO: investigate
closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
......@@ -846,7 +931,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_mask,
reference_fn=reference_rotate_mask,
reference_inputs_fn=reference_inputs_rotate_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(10),
),
KernelInfo(
F.rotate_video,
......@@ -873,7 +959,9 @@ def sample_inputs_crop_image_tensor():
def reference_inputs_crop_image_tensor():
for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _CROP_PARAMS):
for image_loader, params in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS
):
yield ArgsKwargs(image_loader, **params)
......@@ -928,7 +1016,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
),
KernelInfo(
F.crop_bounding_box,
......@@ -941,7 +1029,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_mask,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
),
KernelInfo(
F.crop_video,
......@@ -970,7 +1058,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
def reference_inputs_resized_crop_image_tensor():
for image_loader, interpolation, params in itertools.product(
make_image_loaders(extra_dims=[()]),
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.NEAREST_EXACT,
......@@ -1020,9 +1108,13 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
**CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
# TODO: investigate
**pil_reference_pixel_difference(60, agg_method="mean"),
**cuda_vs_cpu_pixel_difference(),
# TODO: investigate
**float32_vs_uint8_pixel_difference(50),
},
),
KernelInfo(
......@@ -1034,12 +1126,13 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_mask,
reference_fn=pil_reference_wrapper(F.resized_crop_image_pil),
reference_inputs_fn=reference_inputs_resized_crop_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(10),
),
KernelInfo(
F.resized_crop_video,
sample_inputs_fn=sample_inputs_resized_crop_video,
closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
......@@ -1062,10 +1155,7 @@ def sample_inputs_pad_image_tensor():
yield ArgsKwargs(image_loader, padding=padding)
for image_loader in make_pad_image_loaders():
fills = [None, 0.5]
if image_loader.num_channels > 1:
fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
for fill in fills:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, padding=[1], fill=fill)
for image_loader, padding_mode in itertools.product(
......@@ -1082,12 +1172,15 @@ def sample_inputs_pad_image_tensor():
def reference_inputs_pad_image_tensor():
for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _PAD_PARAMS):
for image_loader, params in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
fills = [None, 128.0, 128]
if params["padding_mode"] == "constant":
fills.append([12.0 + c for c in range(image_loader.num_channels)])
for fill in fills:
for fill in get_fills(
num_channels=image_loader.num_channels,
dtype=image_loader.dtype,
vector=params["padding_mode"] == "constant",
):
yield ArgsKwargs(image_loader, fill=fill, **params)
......@@ -1110,8 +1203,10 @@ def sample_inputs_pad_mask():
def reference_inputs_pad_mask():
for image_loader, fill, params in itertools.product(make_image_loaders(extra_dims=[()]), [None, 127], _PAD_PARAMS):
yield ArgsKwargs(image_loader, fill=fill, **params)
for mask_loader, fill, params in itertools.product(
make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
):
yield ArgsKwargs(mask_loader, fill=fill, **params)
def sample_inputs_pad_video():
......@@ -1158,7 +1253,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_image_tensor,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
......@@ -1180,7 +1276,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_mask,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
),
KernelInfo(
F.pad_video,
......@@ -1197,14 +1293,16 @@ _PERSPECTIVE_COEFFS = [
def sample_inputs_perspective_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]):
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])
def reference_inputs_perspective_image_tensor():
for image_loader, coefficients in itertools.product(make_image_loaders(extra_dims=[()]), _PERSPECTIVE_COEFFS):
for image_loader, coefficients in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PERSPECTIVE_COEFFS
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
for fill in [None, 128.0, 128, [12.0 + c for c in range(image_loader.num_channels)]]:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=coefficients)
......@@ -1239,9 +1337,12 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_image_tensor,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
**DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
**CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
# TODO: investigate
**pil_reference_pixel_difference(160, agg_method="mean"),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
......@@ -1253,12 +1354,15 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_mask,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
},
),
KernelInfo(
F.perspective_video,
sample_inputs_fn=sample_inputs_perspective_video,
closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
......@@ -1271,13 +1375,13 @@ def _get_elastic_displacement(spatial_size):
def sample_inputs_elastic_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]):
displacement = _get_elastic_displacement(image_loader.spatial_size)
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
def reference_inputs_elastic_image_tensor():
for image_loader, interpolation in itertools.product(
make_image_loaders(extra_dims=[()]),
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
......@@ -1285,7 +1389,7 @@ def reference_inputs_elastic_image_tensor():
],
):
displacement = _get_elastic_displacement(image_loader.spatial_size)
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
......@@ -1324,7 +1428,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
# TODO: investigate
closeness_kwargs=float32_vs_uint8_pixel_difference(60, agg_method="mean"),
),
KernelInfo(
F.elastic_bounding_box,
......@@ -1335,7 +1441,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_mask,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
# TODO: investigate
closeness_kwargs=pil_reference_pixel_difference(80, agg_method="mean"),
),
KernelInfo(
F.elastic_video,
......@@ -1364,7 +1472,8 @@ def sample_inputs_center_crop_image_tensor():
def reference_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product(
make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
_CENTER_CROP_OUTPUT_SIZES,
):
yield ArgsKwargs(image_loader, output_size=output_size)
......@@ -1405,7 +1514,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
......@@ -1422,7 +1531,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
......@@ -1459,10 +1568,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs={
**DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
**CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
},
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
......@@ -1471,7 +1577,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.gaussian_blur_video,
sample_inputs_fn=sample_inputs_gaussian_blur_video,
closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
),
]
)
......@@ -1506,7 +1612,7 @@ def reference_inputs_equalize_image_tensor():
spatial_size = (256, 256)
for dtype, color_space, fn in itertools.product(
[torch.uint8, torch.float32],
[torch.uint8],
[features.ColorSpace.GRAY, features.ColorSpace.RGB],
[
lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
......@@ -1550,8 +1656,8 @@ KERNEL_INFOS.extend(
kernel_name="equalize_image_tensor",
sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil),
float32_vs_uint8=True,
reference_inputs_fn=reference_inputs_equalize_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.equalize_video,
......@@ -1570,7 +1676,7 @@ def sample_inputs_invert_image_tensor():
def reference_inputs_invert_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader)
......@@ -1588,7 +1694,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
),
KernelInfo(
F.invert_video,
......@@ -1610,7 +1716,9 @@ def sample_inputs_posterize_image_tensor():
def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_POSTERIZE_BITS,
):
yield ArgsKwargs(image_loader, bits=bits)
......@@ -1629,7 +1737,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.posterize_video,
......@@ -1654,12 +1763,16 @@ def sample_inputs_solarize_image_tensor():
def reference_inputs_solarize_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
for threshold in _get_solarize_thresholds(image_loader.dtype):
yield ArgsKwargs(image_loader, threshold=threshold)
def uint8_to_float32_threshold_adapter(other_args, kwargs):
return other_args, dict(threshold=kwargs["threshold"] / 255)
def sample_inputs_solarize_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))
......@@ -1673,7 +1786,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=uint8_to_float32_threshold_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.solarize_video,
......@@ -1692,7 +1806,7 @@ def sample_inputs_autocontrast_image_tensor():
def reference_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader)
......@@ -1710,7 +1824,11 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.autocontrast_video,
......@@ -1732,7 +1850,9 @@ def sample_inputs_adjust_sharpness_image_tensor():
def reference_inputs_adjust_sharpness_image_tensor():
for image_loader, sharpness_factor in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_SHARPNESS_FACTORS,
):
yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
......@@ -1751,7 +1871,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(2),
),
KernelInfo(
F.adjust_sharpness_video,
......@@ -1803,7 +1924,9 @@ def sample_inputs_adjust_brightness_image_tensor():
def reference_inputs_adjust_brightness_image_tensor():
for image_loader, brightness_factor in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_BRIGHTNESS_FACTORS,
):
yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
......@@ -1822,7 +1945,8 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.adjust_brightness_video,
......@@ -1844,7 +1968,9 @@ def sample_inputs_adjust_contrast_image_tensor():
def reference_inputs_adjust_contrast_image_tensor():
for image_loader, contrast_factor in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_CONTRAST_FACTORS,
):
yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
......@@ -1863,7 +1989,11 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(2),
},
),
KernelInfo(
F.adjust_contrast_video,
......@@ -1888,7 +2018,9 @@ def sample_inputs_adjust_gamma_image_tensor():
def reference_inputs_adjust_gamma_image_tensor():
for image_loader, (gamma, gain) in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_GAMMA_GAMMAS_GAINS,
):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
......@@ -1908,7 +2040,11 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.adjust_gamma_video,
......@@ -1930,7 +2066,9 @@ def sample_inputs_adjust_hue_image_tensor():
def reference_inputs_adjust_hue_image_tensor():
for image_loader, hue_factor in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_HUE_FACTORS,
):
yield ArgsKwargs(image_loader, hue_factor=hue_factor)
......@@ -1949,7 +2087,12 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
# TODO: investigate
**pil_reference_pixel_difference(20),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.adjust_hue_video,
......@@ -1970,7 +2113,9 @@ def sample_inputs_adjust_saturation_image_tensor():
def reference_inputs_adjust_saturation_image_tensor():
for image_loader, saturation_factor in itertools.product(
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_SATURATION_FACTORS,
):
yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
......@@ -1989,7 +2134,11 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(2),
},
),
KernelInfo(
F.adjust_saturation_video,
......@@ -2038,7 +2187,9 @@ def sample_inputs_five_crop_image_tensor():
def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size)
......@@ -2060,7 +2211,9 @@ def sample_inputs_ten_crop_image_tensor():
def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
......@@ -2070,6 +2223,17 @@ def sample_inputs_ten_crop_video():
yield ArgsKwargs(video_loader, size=size)
def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)(
F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
for output_pil in output
)
return wrapper
_common_five_ten_crop_marks = [
xfail_jit_python_scalar_arg("size"),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
......@@ -2080,10 +2244,9 @@ KERNEL_INFOS.extend(
KernelInfo(
F.five_crop_image_tensor,
sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.five_crop_video,
......@@ -2093,10 +2256,9 @@ KERNEL_INFOS.extend(
KernelInfo(
F.ten_crop_image_tensor,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.ten_crop_video,
......
......@@ -244,16 +244,19 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, threshold=0.99),
],
),
ConsistencyConfig(
prototype_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
*[
ConsistencyConfig(
prototype_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
closeness_kwargs=ckw,
)
for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
],
ConsistencyConfig(
prototype_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness,
......@@ -1007,7 +1010,7 @@ class TestRefSegTransforms:
dp = (conv_fn(feature_image), feature_mask)
dp_ref = (
to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
to_image_pil(feature_mask),
)
......@@ -1021,12 +1024,16 @@ class TestRefSegTransforms:
for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
self.set_seed()
output = t(dp)
actual = actual_image, actual_mask = t(dp)
self.set_seed()
expected_output = t_ref(*dp_ref)
expected_image, expected_mask = t_ref(*dp_ref)
if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
expected_image = legacy_F.pil_to_tensor(expected_image)
expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
expected = (expected_image, expected_mask)
assert_equal(output, expected_output)
assert_equal(actual, expected)
@pytest.mark.parametrize(
("t_ref", "t", "data_kwargs"),
......
......@@ -11,7 +11,7 @@ import pytest
import torch
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
......@@ -22,6 +22,10 @@ from torchvision.prototype.transforms.functional._meta import convert_format_bou
from torchvision.transforms.functional import _get_perspective_coeffs
KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}
@cache
def script(fn):
try:
......@@ -127,6 +131,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
def _unbatch(self, batch, *, data_dims):
......@@ -183,6 +188,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
@sample_inputs
......@@ -212,6 +218,7 @@ class TestKernels:
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
@sample_inputs
......@@ -237,8 +244,35 @@ class TestKernels:
assert_close(
actual,
expected,
check_dtype=False,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
@make_info_args_kwargs_parametrization(
[info for info in KERNEL_INFOS if info.float32_vs_uint8],
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_float32_vs_uint8(self, test_id, info, args_kwargs):
(input, *other_args), kwargs = args_kwargs.load("cpu")
if input.dtype != torch.uint8:
pytest.skip(f"Input dtype is {input.dtype}.")
adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
actual = info.kernel(
F.convert_dtype_image_tensor(input, dtype=torch.float32),
*adapted_other_args,
**adapted_kwargs,
)
expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
......@@ -421,12 +455,12 @@ def test_alias(alias, target):
@pytest.mark.parametrize(
("info", "args_kwargs"),
make_info_args_kwargs_params(
next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device):
def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
(input, *other_args), kwargs = args_kwargs.load(device)
dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment