Unverified commit 65769ab7 authored by Philip Meier, committed by GitHub

fix prototype transforms tests with set agg_method (#6934)

* fix prototype transforms tests with set agg_method

* use individual tolerances

* refactor PIL reference test

* increase tolerance for elastic_mask

* fix autocontrast tolerances

* increase tolerance for RandomAutocontrast
parent d72e9064
@@ -12,17 +12,9 @@ import torch
 import torch.testing
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = [
@@ -54,7 +46,7 @@ __all__ = [
 ]


-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
     def __init__(
         self,
         actual,
@@ -64,44 +56,13 @@ class PILImagePair(TensorLikePair):
         allowed_percentage_diff=None,
         **other_parameters,
     ):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
-
-        # This parameter is ignored to enable checking PIL images to tensor images not on the CPU
-        other_parameters["check_device"] = False
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
         self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
         self.allowed_percentage_diff = allowed_percentage_diff

-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
-                id=id,
-            ) from None
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-        return super()._equalize_attributes(actual, expected)
-
     def compare(self) -> None:
         actual, expected = self.actual, self.expected
@@ -111,16 +72,24 @@ class PILImagePair(TensorLikePair):
         abs_diff = torch.abs(actual - expected)

         if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
             if percentage_diff > self.allowed_percentage_diff:
-                self._make_error_meta(AssertionError, "percentage mismatch")
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"{percentage_diff:.1%} elements differ, "
+                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                )

         if self.agg_method is None:
             super()._compare_values(actual, expected)
         else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed.",
+                )

 def assert_close(
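The aggregated comparison above replaces the per-element check with a single scalar criterion: reduce the absolute difference with `agg_method` and compare the result against `atol`. A standalone sketch of that logic (values are hypothetical; only `torch` is assumed):

import torch

actual = torch.tensor([0.0, 2.0, 4.0])
expected = torch.tensor([1.0, 2.0, 3.0])

abs_diff = torch.abs(actual - expected)  # tensor([1., 0., 1.])
agg_abs_diff = float(torch.mean(abs_diff.to(torch.float64)))  # 0.666...; agg_method="mean" resolves to torch.mean
assert agg_abs_diff <= 1.0  # passes with atol=1.0, even though no single element matches exactly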
@@ -148,7 +117,7 @@ def assert_close(
             NonePair,
             BooleanPair,
             NumberPair,
-            PILImagePair,
+            ImagePair,
             TensorLikePair,
         ),
         allow_subclasses=allow_subclasses,
@@ -167,6 +136,32 @@ def assert_close(
 assert_equal = functools.partial(assert_close, rtol=0, atol=0)
+def parametrized_error_message(*args, **kwargs):
+    def to_str(obj):
+        if isinstance(obj, torch.Tensor) and obj.numel() > 10:
+            return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
+        else:
+            return repr(obj)
+
+    if args or kwargs:
+        postfix = "\n".join(
+            [
+                "",
+                "Failure happened for the following parameters:",
+                "",
+                *[to_str(arg) for arg in args],
+                *[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
+            ]
+        )
+    else:
+        postfix = ""
+
+    def wrapper(msg):
+        return msg + postfix
+
+    return wrapper
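`parametrized_error_message` is new plumbing for the test suite: it builds a wrapper that appends the failing parametrization to an error message, summarizing large tensors instead of printing them in full. A quick sketch of its effect (the sample values are made up):

msg_fn = parametrized_error_message("nearest", antialias=False)
print(msg_fn("Tensors are not close!"))
# Tensors are not close!
# Failure happened for the following parameters:
#
# 'nearest'
# antialias=False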
 class ArgsKwargs:
     def __init__(self, *args, **kwargs):
         self.args = args
@@ -656,6 +651,13 @@ class InfoBase:
         ]

     def get_closeness_kwargs(self, test_id, *, dtype, device):
+        if not (isinstance(test_id, tuple) and len(test_id) == 2):
+            msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
+            if callable(test_id):
+                msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
+            else:
+                msg += f", but got {test_id} instead."
+            raise pytest.UsageError(msg)
         if isinstance(device, torch.device):
             device = device.type
         return self.closeness_kwargs.get((test_id, dtype, device), dict())
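With the added validation, `get_closeness_kwargs` remains a plain dictionary lookup keyed by `(test_id, dtype, device)`. A minimal sketch of that lookup (the key and tolerances are hypothetical examples):

import torch

closeness_kwargs = {
    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
}
test_id = ("TestKernels", "test_against_reference")
assert closeness_kwargs.get((test_id, torch.uint8, "cpu"), dict()) == dict(atol=10, rtol=0)
assert closeness_kwargs.get((test_id, torch.float32, "cpu"), dict()) == dict()  # falls back to the defaults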
@@ -4,6 +4,7 @@ import itertools
 import math

 import numpy as np
+import PIL.Image
 import pytest
 import torch.testing
 import torchvision.ops
@@ -49,6 +50,12 @@ class KernelInfo(InfoBase):
         # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
         # values to be tested. If not specified, `sample_inputs_fn` will be used.
         reference_inputs_fn=None,
+        # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
+        # reference inputs. This is usually used whenever we use a PIL kernel as reference.
+        # Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
+        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
+        # dtype.
+        float32_vs_uint8=False,
         # See InfoBase
         test_marks=None,
         # See InfoBase
@@ -60,28 +67,64 @@
         self.reference_fn = reference_fn
         self.reference_inputs_fn = reference_inputs_fn
+        if float32_vs_uint8 and not callable(float32_vs_uint8):
+            float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
+        self.float32_vs_uint8 = float32_vs_uint8


-DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
-    (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
-    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
-}
-
-CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE = {
-    (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): dict(atol=atol, rtol=0)
-    for dtype, atol in [(torch.uint8, 1), (torch.float32, 1 / 255)]
-}
+def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, agg_method=None):
+    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, agg_method=agg_method)
+
+
+def cuda_vs_cpu_pixel_difference(atol=1):
+    return {
+        (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): _pixel_difference_closeness_kwargs(atol, dtype=dtype)
+        for dtype in [torch.uint8, torch.float32]
+    }
+
+
+def pil_reference_pixel_difference(atol=1, agg_method=None):
+    return {
+        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): _pixel_difference_closeness_kwargs(
+            atol, agg_method=agg_method
+        )
+    }
+
+
+def float32_vs_uint8_pixel_difference(atol=1, agg_method=None):
+    return {
+        (
+            ("TestKernels", "test_float32_vs_uint8"),
+            torch.float32,
+            "cpu",
+        ): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, agg_method=agg_method)
+    }
 def pil_reference_wrapper(pil_kernel):
     @functools.wraps(pil_kernel)
-    def wrapper(image_tensor, *other_args, **kwargs):
-        if image_tensor.ndim > 3:
+    def wrapper(input_tensor, *other_args, **kwargs):
+        if input_tensor.dtype != torch.uint8:
+            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
+        if input_tensor.ndim > 3:
             raise pytest.UsageError(
-                f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}"
+                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
             )
-        # We don't need to convert back to tensor here, since `assert_close` does that automatically.
-        return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs)
+
+        input_pil = F.to_image_pil(input_tensor)
+        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
+        if not isinstance(output_pil, PIL.Image.Image):
+            return output_pil
+
+        output_tensor = F.to_image_tensor(output_pil)
+
+        # 2D mask shenanigans
+        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
+            output_tensor = output_tensor.unsqueeze(0)
+        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
+            output_tensor = output_tensor.squeeze(0)
+
+        return output_tensor

     return wrapper
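The wrapper now enforces uint8 inputs and converts the PIL result back to a tensor itself (including the 2D mask special case), instead of relying on the comparison machinery to do so. Roughly how it is used as a reference function (an illustrative sketch; the kernel names are taken from the diff below, and flips happen to match the PIL reference exactly):

import torch

reference_fn = pil_reference_wrapper(F.horizontal_flip_image_pil)
image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

expected = reference_fn(image)                  # runs the PIL kernel, returns a uint8 tensor
actual = F.horizontal_flip_image_tensor(image)
torch.testing.assert_close(actual, expected, rtol=0, atol=0)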
@@ -126,7 +169,7 @@ def sample_inputs_horizontal_flip_image_tensor():
 def reference_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(extra_dims=[()]):
+    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)
@@ -180,7 +223,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
         reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
         reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
     ),
     KernelInfo(
         F.horizontal_flip_bounding_box,
@@ -244,7 +287,7 @@ def reference_resize_image_tensor(*args, **kwargs):
 def reference_inputs_resize_image_tensor():
     for image_loader, interpolation in itertools.product(
-        make_image_loaders(extra_dims=[()]),
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
         [
             F.InterpolationMode.NEAREST,
             F.InterpolationMode.NEAREST_EXACT,
@@ -324,9 +367,13 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_resize_image_tensor,
         reference_fn=reference_resize_image_tensor,
         reference_inputs_fn=reference_inputs_resize_image_tensor,
+        float32_vs_uint8=True,
         closeness_kwargs={
-            **DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
-            **CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+            # TODO: investigate
+            **pil_reference_pixel_difference(110, agg_method="mean"),
+            **cuda_vs_cpu_pixel_difference(),
+            # TODO: investigate
+            **float32_vs_uint8_pixel_difference(50),
         },
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
@@ -346,7 +393,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_resize_mask,
         reference_fn=reference_resize_mask,
         reference_inputs_fn=reference_inputs_resize_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=pil_reference_pixel_difference(10),
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
         ],
@@ -354,7 +402,7 @@ KERNEL_INFOS.extend(
     KernelInfo(
         F.resize_video,
         sample_inputs_fn=sample_inputs_resize_video,
-        closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
     ),
     ]
 )
@@ -400,6 +448,36 @@ _DIVERSE_AFFINE_PARAMS = [
 ]
+def get_fills(*, num_channels, dtype, vector=True):
+    yield None
+
+    max_value = get_max_value(dtype)
+    # This intentionally gives us a float and an int scalar fill value
+    yield max_value / 2
+    yield max_value
+
+    if not vector:
+        return
+
+    if dtype.is_floating_point:
+        yield [0.1 + c / 10 for c in range(num_channels)]
+    else:
+        yield [12.0 + c for c in range(num_channels)]
+
+
+def float32_vs_uint8_fill_adapter(other_args, kwargs):
+    fill = kwargs.get("fill")
+    if fill is None:
+        return other_args, kwargs
+
+    if isinstance(fill, (int, float)):
+        fill /= 255
+    else:
+        fill = type(fill)(fill_ / 255 for fill_ in fill)
+
+    return other_args, dict(kwargs, fill=fill)
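For a concrete picture of what `get_fills` generates and how the fill adapter rescales it for the float32 leg of the consistency test, a sketch for a 3-channel uint8 image (only `torch` is assumed):

import torch

fills = list(get_fills(num_channels=3, dtype=torch.uint8))
# [None, 127.5, 255, [12.0, 13.0, 14.0]]

_, kwargs = float32_vs_uint8_fill_adapter((), dict(fill=[12.0, 13.0, 14.0]))
# kwargs["fill"] == [12 / 255, 13 / 255, 14 / 255], i.e. the same fill on the [0, 1] scale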
 def sample_inputs_affine_image_tensor():
     make_affine_image_loaders = functools.partial(
         make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
@@ -409,10 +487,7 @@ def sample_inputs_affine_image_tensor():
         yield ArgsKwargs(image_loader, **affine_params)

     for image_loader in make_affine_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, **_full_affine_params(), fill=fill)

     for image_loader, interpolation in itertools.product(
@@ -426,7 +501,9 @@ def sample_inputs_affine_image_tensor():
 def reference_inputs_affine_image_tensor():
-    for image_loader, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
+    for image_loader, affine_kwargs in itertools.product(
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _AFFINE_KWARGS
+    ):
         yield ArgsKwargs(
             image_loader,
             interpolation=F.InterpolationMode.NEAREST,
@@ -564,7 +641,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_affine_image_tensor,
         reference_fn=pil_reference_wrapper(F.affine_image_pil),
         reference_inputs_fn=reference_inputs_affine_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=pil_reference_pixel_difference(10, agg_method="mean"),
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
             xfail_jit_tuple_instead_of_list("fill"),
@@ -589,7 +667,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_affine_mask,
         reference_fn=reference_affine_mask,
         reference_inputs_fn=reference_inputs_resize_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        closeness_kwargs=pil_reference_pixel_difference(10),
+        float32_vs_uint8=True,
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
         ],
@@ -631,7 +710,9 @@ KERNEL_INFOS.append(
 def sample_inputs_convert_color_space_image_tensor():
-    color_spaces = list(set(features.ColorSpace) - {features.ColorSpace.OTHER})
+    color_spaces = sorted(
+        set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value
+    )

     for old_color_space, new_color_space in cycle_over(color_spaces):
         for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
@@ -659,7 +740,7 @@ def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
 def reference_inputs_convert_color_space_image_tensor():
     for args_kwargs in sample_inputs_convert_color_space_image_tensor():
         (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 3:
+        if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
             yield args_kwargs
@@ -678,7 +759,10 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
         reference_fn=reference_convert_color_space_image_tensor,
         reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        closeness_kwargs={
+            **pil_reference_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(),
+        },
     ),
     KernelInfo(
         F.convert_color_space_video,
@@ -694,7 +778,7 @@ def sample_inputs_vertical_flip_image_tensor():
 def reference_inputs_vertical_flip_image_tensor():
-    for image_loader in make_image_loaders(extra_dims=[()]):
+    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)
@@ -739,7 +823,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
         reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
         reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
     ),
     KernelInfo(
         F.vertical_flip_bounding_box,
@@ -775,10 +859,7 @@ def sample_inputs_rotate_image_tensor():
         yield ArgsKwargs(image_loader, angle=15.0, center=center)

     for image_loader in make_rotate_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, angle=15.0, fill=fill)

     for image_loader, interpolation in itertools.product(
@@ -789,7 +870,9 @@ def sample_inputs_rotate_image_tensor():
 def reference_inputs_rotate_image_tensor():
-    for image_loader, angle in itertools.product(make_image_loaders(extra_dims=[()]), _ROTATE_ANGLES):
+    for image_loader, angle in itertools.product(
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _ROTATE_ANGLES
+    ):
         yield ArgsKwargs(image_loader, angle=angle)
@@ -830,7 +913,9 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_rotate_image_tensor,
         reference_fn=pil_reference_wrapper(F.rotate_image_pil),
         reference_inputs_fn=reference_inputs_rotate_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        # TODO: investigate
+        closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
         test_marks=[
             xfail_jit_tuple_instead_of_list("fill"),
             # TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -846,7 +931,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_rotate_mask,
         reference_fn=reference_rotate_mask,
         reference_inputs_fn=reference_inputs_rotate_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=pil_reference_pixel_difference(10),
     ),
     KernelInfo(
         F.rotate_video,
@@ -873,7 +959,9 @@ def sample_inputs_crop_image_tensor():
 def reference_inputs_crop_image_tensor():
-    for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _CROP_PARAMS):
+    for image_loader, params in itertools.product(
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS
+    ):
         yield ArgsKwargs(image_loader, **params)
@@ -928,7 +1016,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_crop_image_tensor,
         reference_fn=pil_reference_wrapper(F.crop_image_pil),
         reference_inputs_fn=reference_inputs_crop_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
     ),
     KernelInfo(
         F.crop_bounding_box,
@@ -941,7 +1029,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_crop_mask,
         reference_fn=pil_reference_wrapper(F.crop_image_pil),
         reference_inputs_fn=reference_inputs_crop_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
     ),
     KernelInfo(
         F.crop_video,
@@ -970,7 +1058,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
 def reference_inputs_resized_crop_image_tensor():
     for image_loader, interpolation, params in itertools.product(
-        make_image_loaders(extra_dims=[()]),
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
         [
             F.InterpolationMode.NEAREST,
             F.InterpolationMode.NEAREST_EXACT,
@@ -1020,9 +1108,13 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
         reference_fn=reference_resized_crop_image_tensor,
         reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
+        float32_vs_uint8=True,
         closeness_kwargs={
-            **DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
-            **CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+            # TODO: investigate
+            **pil_reference_pixel_difference(60, agg_method="mean"),
+            **cuda_vs_cpu_pixel_difference(),
+            # TODO: investigate
+            **float32_vs_uint8_pixel_difference(50),
         },
     ),
     KernelInfo(
@@ -1034,12 +1126,13 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_resized_crop_mask,
         reference_fn=pil_reference_wrapper(F.resized_crop_image_pil),
         reference_inputs_fn=reference_inputs_resized_crop_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=pil_reference_pixel_difference(10),
     ),
     KernelInfo(
         F.resized_crop_video,
         sample_inputs_fn=sample_inputs_resized_crop_video,
-        closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
     ),
     ]
 )
@@ -1062,10 +1155,7 @@ def sample_inputs_pad_image_tensor():
         yield ArgsKwargs(image_loader, padding=padding)

     for image_loader in make_pad_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, padding=[1], fill=fill)

     for image_loader, padding_mode in itertools.product(
@@ -1082,12 +1172,15 @@ def sample_inputs_pad_image_tensor():
 def reference_inputs_pad_image_tensor():
-    for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _PAD_PARAMS):
+    for image_loader, params in itertools.product(
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
+    ):
         # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        fills = [None, 128.0, 128]
-        if params["padding_mode"] == "constant":
-            fills.append([12.0 + c for c in range(image_loader.num_channels)])
-        for fill in fills:
+        for fill in get_fills(
+            num_channels=image_loader.num_channels,
+            dtype=image_loader.dtype,
+            vector=params["padding_mode"] == "constant",
+        ):
             yield ArgsKwargs(image_loader, fill=fill, **params)
@@ -1110,8 +1203,10 @@ def sample_inputs_pad_mask():
 def reference_inputs_pad_mask():
-    for image_loader, fill, params in itertools.product(make_image_loaders(extra_dims=[()]), [None, 127], _PAD_PARAMS):
-        yield ArgsKwargs(image_loader, fill=fill, **params)
+    for mask_loader, fill, params in itertools.product(
+        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
+    ):
+        yield ArgsKwargs(mask_loader, fill=fill, **params)


 def sample_inputs_pad_video():
@@ -1158,7 +1253,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_pad_image_tensor,
         reference_fn=pil_reference_wrapper(F.pad_image_pil),
         reference_inputs_fn=reference_inputs_pad_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=float32_vs_uint8_fill_adapter,
+        closeness_kwargs=float32_vs_uint8_pixel_difference(),
         test_marks=[
             xfail_jit_tuple_instead_of_list("padding"),
             xfail_jit_tuple_instead_of_list("fill"),
@@ -1180,7 +1276,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_pad_mask,
         reference_fn=pil_reference_wrapper(F.pad_image_pil),
         reference_inputs_fn=reference_inputs_pad_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=float32_vs_uint8_fill_adapter,
     ),
     KernelInfo(
         F.pad_video,
@@ -1197,14 +1293,16 @@ _PERSPECTIVE_COEFFS = [
 def sample_inputs_perspective_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"]):
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])


 def reference_inputs_perspective_image_tensor():
-    for image_loader, coefficients in itertools.product(make_image_loaders(extra_dims=[()]), _PERSPECTIVE_COEFFS):
+    for image_loader, coefficients in itertools.product(
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PERSPECTIVE_COEFFS
+    ):
         # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        for fill in [None, 128.0, 128, [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=coefficients)
@@ -1239,9 +1337,12 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_perspective_image_tensor,
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_image_tensor,
+        float32_vs_uint8=float32_vs_uint8_fill_adapter,
         closeness_kwargs={
-            **DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
-            **CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+            # TODO: investigate
+            **pil_reference_pixel_difference(160, agg_method="mean"),
+            **cuda_vs_cpu_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(),
         },
     ),
     KernelInfo(
@@ -1253,12 +1354,15 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_perspective_mask,
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
+        },
     ),
     KernelInfo(
         F.perspective_video,
         sample_inputs_fn=sample_inputs_perspective_video,
-        closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
     ),
     ]
 )
@@ -1271,13 +1375,13 @@ def _get_elastic_displacement(spatial_size):
 def sample_inputs_elastic_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"]):
         displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


 def reference_inputs_elastic_image_tensor():
     for image_loader, interpolation in itertools.product(
-        make_image_loaders(extra_dims=[()]),
+        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]),
         [
             F.InterpolationMode.NEAREST,
             F.InterpolationMode.BILINEAR,
@@ -1285,7 +1389,7 @@ def reference_inputs_elastic_image_tensor():
         ],
     ):
         displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
@@ -1324,7 +1428,9 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_elastic_image_tensor,
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=float32_vs_uint8_fill_adapter,
+        # TODO: investigate
+        closeness_kwargs=float32_vs_uint8_pixel_difference(60, agg_method="mean"),
     ),
     KernelInfo(
         F.elastic_bounding_box,
@@ -1335,7 +1441,9 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_elastic_mask,
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        # TODO: investigate
+        closeness_kwargs=pil_reference_pixel_difference(80, agg_method="mean"),
     ),
     KernelInfo(
         F.elastic_video,
@@ -1364,7 +1472,8 @@ def sample_inputs_center_crop_image_tensor():
 def reference_inputs_center_crop_image_tensor():
     for image_loader, output_size in itertools.product(
-        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
+        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
+        _CENTER_CROP_OUTPUT_SIZES,
     ):
         yield ArgsKwargs(image_loader, output_size=output_size)
@@ -1405,7 +1514,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_center_crop_image_tensor,
         reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
         reference_inputs_fn=reference_inputs_center_crop_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
         test_marks=[
             xfail_jit_python_scalar_arg("output_size"),
         ],
@@ -1422,7 +1531,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_center_crop_mask,
         reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
         reference_inputs_fn=reference_inputs_center_crop_mask,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
         test_marks=[
             xfail_jit_python_scalar_arg("output_size"),
         ],
@@ -1459,10 +1568,7 @@ KERNEL_INFOS.extend(
     KernelInfo(
         F.gaussian_blur_image_tensor,
         sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
-        closeness_kwargs={
-            **DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
-            **CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
-        },
+        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
@@ -1471,7 +1577,7 @@ KERNEL_INFOS.extend(
     KernelInfo(
         F.gaussian_blur_video,
         sample_inputs_fn=sample_inputs_gaussian_blur_video,
-        closeness_kwargs=CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE,
+        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
     ),
     ]
 )
@@ -1506,7 +1612,7 @@ def reference_inputs_equalize_image_tensor():
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
-        [torch.uint8, torch.float32],
+        [torch.uint8],
         [features.ColorSpace.GRAY, features.ColorSpace.RGB],
         [
             lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
@@ -1550,8 +1656,8 @@ KERNEL_INFOS.extend(
         kernel_name="equalize_image_tensor",
         sample_inputs_fn=sample_inputs_equalize_image_tensor,
         reference_fn=pil_reference_wrapper(F.equalize_image_pil),
+        float32_vs_uint8=True,
         reference_inputs_fn=reference_inputs_equalize_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
     ),
     KernelInfo(
         F.equalize_video,
@@ -1570,7 +1676,7 @@ def sample_inputs_invert_image_tensor():
 def reference_inputs_invert_image_tensor():
     for image_loader in make_image_loaders(
-        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
+        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
     ):
         yield ArgsKwargs(image_loader)
@@ -1588,7 +1694,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_invert_image_tensor,
         reference_fn=pil_reference_wrapper(F.invert_image_pil),
         reference_inputs_fn=reference_inputs_invert_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
     ),
     KernelInfo(
         F.invert_video,
@@ -1610,7 +1716,9 @@ def sample_inputs_posterize_image_tensor():
 def reference_inputs_posterize_image_tensor():
     for image_loader, bits in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _POSTERIZE_BITS,
     ):
         yield ArgsKwargs(image_loader, bits=bits)
@@ -1629,7 +1737,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_posterize_image_tensor,
         reference_fn=pil_reference_wrapper(F.posterize_image_pil),
         reference_inputs_fn=reference_inputs_posterize_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=float32_vs_uint8_pixel_difference(),
     ),
     KernelInfo(
         F.posterize_video,
@@ -1654,12 +1763,16 @@ def sample_inputs_solarize_image_tensor():
 def reference_inputs_solarize_image_tensor():
     for image_loader in make_image_loaders(
-        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
+        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
     ):
         for threshold in _get_solarize_thresholds(image_loader.dtype):
             yield ArgsKwargs(image_loader, threshold=threshold)


+def uint8_to_float32_threshold_adapter(other_args, kwargs):
+    return other_args, dict(threshold=kwargs["threshold"] / 255)
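Like the fill adapter, this rescales the solarize threshold from uint8 units to the [0, 1] float32 range before the float32 leg of the consistency test runs. For example (values are hypothetical):

_, kwargs = uint8_to_float32_threshold_adapter((), dict(threshold=128))
assert abs(kwargs["threshold"] - 128 / 255) < 1e-9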
 def sample_inputs_solarize_video():
     for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
         yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))
@@ -1673,7 +1786,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_solarize_image_tensor,
         reference_fn=pil_reference_wrapper(F.solarize_image_pil),
         reference_inputs_fn=reference_inputs_solarize_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=uint8_to_float32_threshold_adapter,
+        closeness_kwargs=float32_vs_uint8_pixel_difference(),
     ),
     KernelInfo(
         F.solarize_video,
@@ -1692,7 +1806,7 @@ def sample_inputs_autocontrast_image_tensor():
 def reference_inputs_autocontrast_image_tensor():
     for image_loader in make_image_loaders(
-        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
+        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
     ):
         yield ArgsKwargs(image_loader)
@@ -1710,7 +1824,11 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
         reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
         reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            **pil_reference_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(),
+        },
     ),
     KernelInfo(
         F.autocontrast_video,
@@ -1732,7 +1850,9 @@ def sample_inputs_adjust_sharpness_image_tensor():
 def reference_inputs_adjust_sharpness_image_tensor():
     for image_loader, sharpness_factor in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_SHARPNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
@@ -1751,7 +1871,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
         reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=float32_vs_uint8_pixel_difference(2),
     ),
     KernelInfo(
         F.adjust_sharpness_video,
@@ -1803,7 +1924,9 @@ def sample_inputs_adjust_brightness_image_tensor():
 def reference_inputs_adjust_brightness_image_tensor():
     for image_loader, brightness_factor in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_BRIGHTNESS_FACTORS,
     ):
         yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
@@ -1822,7 +1945,8 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
         reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs=float32_vs_uint8_pixel_difference(),
     ),
     KernelInfo(
         F.adjust_brightness_video,
@@ -1844,7 +1968,9 @@ def sample_inputs_adjust_contrast_image_tensor():
 def reference_inputs_adjust_contrast_image_tensor():
     for image_loader, contrast_factor in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_CONTRAST_FACTORS,
     ):
         yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
@@ -1863,7 +1989,11 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
         reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            **pil_reference_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(2),
+        },
     ),
     KernelInfo(
         F.adjust_contrast_video,
@@ -1888,7 +2018,9 @@ def sample_inputs_adjust_gamma_image_tensor():
 def reference_inputs_adjust_gamma_image_tensor():
     for image_loader, (gamma, gain) in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_GAMMA_GAMMAS_GAINS,
     ):
         yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
@@ -1908,7 +2040,11 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
         reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            **pil_reference_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(),
+        },
     ),
     KernelInfo(
         F.adjust_gamma_video,
@@ -1930,7 +2066,9 @@ def sample_inputs_adjust_hue_image_tensor():
 def reference_inputs_adjust_hue_image_tensor():
     for image_loader, hue_factor in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_HUE_FACTORS,
     ):
         yield ArgsKwargs(image_loader, hue_factor=hue_factor)
@@ -1949,7 +2087,12 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
         reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            # TODO: investigate
+            **pil_reference_pixel_difference(20),
+            **float32_vs_uint8_pixel_difference(),
+        },
     ),
     KernelInfo(
         F.adjust_hue_video,
@@ -1970,7 +2113,9 @@ def sample_inputs_adjust_saturation_image_tensor():
 def reference_inputs_adjust_saturation_image_tensor():
     for image_loader, saturation_factor in itertools.product(
-        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
+        make_image_loaders(
+            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
+        ),
         _ADJUST_SATURATION_FACTORS,
     ):
         yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
@@ -1989,7 +2134,11 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
         reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
         reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        float32_vs_uint8=True,
+        closeness_kwargs={
+            **pil_reference_pixel_difference(),
+            **float32_vs_uint8_pixel_difference(2),
+        },
     ),
     KernelInfo(
         F.adjust_saturation_video,
...@@ -2038,7 +2187,9 @@ def sample_inputs_five_crop_image_tensor(): ...@@ -2038,7 +2187,9 @@ def sample_inputs_five_crop_image_tensor():
def reference_inputs_five_crop_image_tensor(): def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]): for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
...@@ -2060,7 +2211,9 @@ def sample_inputs_ten_crop_image_tensor(): ...@@ -2060,7 +2211,9 @@ def sample_inputs_ten_crop_image_tensor():
def reference_inputs_ten_crop_image_tensor(): def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]): for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
...@@ -2070,6 +2223,17 @@ def sample_inputs_ten_crop_video(): ...@@ -2070,6 +2223,17 @@ def sample_inputs_ten_crop_video():
yield ArgsKwargs(video_loader, size=size) yield ArgsKwargs(video_loader, size=size)
def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)(
F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
for output_pil in output
)
return wrapper
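Unlike the single-image pil_reference_wrapper, the five/ten-crop PIL kernels return a tuple of crops, so each element is converted back to a tensor of the input's dtype individually. A usage sketch, assuming the wrapper accepts a uint8 tensor image directly:

image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
reference_fn = multi_crop_pil_reference_wrapper(F.five_crop_image_pil)
crops = reference_fn(image, size=(16, 16))
assert len(crops) == 5 and all(crop.dtype == image.dtype for crop in crops)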
_common_five_ten_crop_marks = [ _common_five_ten_crop_marks = [
xfail_jit_python_scalar_arg("size"), xfail_jit_python_scalar_arg("size"),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."), mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
...@@ -2080,10 +2244,9 @@ KERNEL_INFOS.extend( ...@@ -2080,10 +2244,9 @@ KERNEL_INFOS.extend(
KernelInfo( KernelInfo(
F.five_crop_image_tensor, F.five_crop_image_tensor,
sample_inputs_fn=sample_inputs_five_crop_image_tensor, sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.five_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor, reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.five_crop_video, F.five_crop_video,
...@@ -2093,10 +2256,9 @@ KERNEL_INFOS.extend( ...@@ -2093,10 +2256,9 @@ KERNEL_INFOS.extend(
KernelInfo( KernelInfo(
F.ten_crop_image_tensor, F.ten_crop_image_tensor,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor, sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor, reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.ten_crop_video, F.ten_crop_video,
......
...@@ -244,16 +244,19 @@ CONSISTENCY_CONFIGS = [ ...@@ -244,16 +244,19 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, threshold=0.99), ArgsKwargs(p=1, threshold=0.99),
], ],
), ),
ConsistencyConfig( *[
prototype_transforms.RandomAutocontrast, ConsistencyConfig(
legacy_transforms.RandomAutocontrast, prototype_transforms.RandomAutocontrast,
[ legacy_transforms.RandomAutocontrast,
ArgsKwargs(p=0), [
ArgsKwargs(p=1), ArgsKwargs(p=0),
], ArgsKwargs(p=1),
# Use default tolerances of `torch.testing.assert_close` ],
closeness_kwargs=dict(rtol=None, atol=None), make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
), closeness_kwargs=ckw,
)
for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
],
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomAdjustSharpness, prototype_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness, legacy_transforms.RandomAdjustSharpness,
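Spelled out, the comprehension above generates two configs: uint8 inputs are compared with a one-level absolute pixel tolerance, while float32 inputs fall back to the torch.testing.assert_close defaults:

ConsistencyConfig(
    prototype_transforms.RandomAutocontrast,
    legacy_transforms.RandomAutocontrast,
    [ArgsKwargs(p=0), ArgsKwargs(p=1)],
    make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    closeness_kwargs=dict(atol=1, rtol=0),
),
ConsistencyConfig(
    prototype_transforms.RandomAutocontrast,
    legacy_transforms.RandomAutocontrast,
    [ArgsKwargs(p=0), ArgsKwargs(p=1)],
    make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float32]),
    closeness_kwargs=dict(rtol=None, atol=None),
),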
...@@ -1007,7 +1010,7 @@ class TestRefSegTransforms: ...@@ -1007,7 +1010,7 @@ class TestRefSegTransforms:
dp = (conv_fn(feature_image), feature_mask) dp = (conv_fn(feature_image), feature_mask)
dp_ref = ( dp_ref = (
to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image), to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
to_image_pil(feature_mask), to_image_pil(feature_mask),
) )
...@@ -1021,12 +1024,16 @@ class TestRefSegTransforms: ...@@ -1021,12 +1024,16 @@ class TestRefSegTransforms:
for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()): for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
self.set_seed() self.set_seed()
output = t(dp) actual = actual_image, actual_mask = t(dp)
self.set_seed() self.set_seed()
expected_output = t_ref(*dp_ref) expected_image, expected_mask = t_ref(*dp_ref)
if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
expected_image = legacy_F.pil_to_tensor(expected_image)
expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
expected = (expected_image, expected_mask)
assert_equal(output, expected_output) assert_equal(actual, expected)
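The .squeeze(0) above is needed because pil_to_tensor always returns a channel-first tensor: a mode "L" mask comes back as (1, H, W), while the prototype mask datapoint is plain (H, W). A quick check:

import PIL.Image
from torchvision.transforms import functional as legacy_F

mask_pil = PIL.Image.new("L", (4, 4))
assert legacy_F.pil_to_tensor(mask_pil).shape == (1, 4, 4)
assert legacy_F.pil_to_tensor(mask_pil).squeeze(0).shape == (4, 4)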
@pytest.mark.parametrize( @pytest.mark.parametrize(
("t_ref", "t", "data_kwargs"), ("t_ref", "t", "data_kwargs"),
......
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
import torch import torch
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -22,6 +22,10 @@ from torchvision.prototype.transforms.functional._meta import convert_format_bou ...@@ -22,6 +22,10 @@ from torchvision.prototype.transforms.functional._meta import convert_format_bou
from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.functional import _get_perspective_coeffs
KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}
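These lookup tables let tests fetch an info object by its kernel or dispatcher directly instead of a linear next(...) search over the info lists, e.g.:

info = KERNEL_INFOS_MAP[F.convert_dtype_image_tensor]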
@cache @cache
def script(fn): def script(fn):
try: try:
...@@ -127,6 +131,7 @@ class TestKernels: ...@@ -127,6 +131,7 @@ class TestKernels:
actual, actual,
expected, expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device), **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
) )
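parametrized_error_message is imported above but defined elsewhere; torch.testing.assert_close accepts msg as a callable that receives the generated failure message, so one plausible implementation (a sketch, not necessarily the one in prototype_common_utils) is:

def parametrized_error_message(*args, **kwargs):
    # Append the offending sample parameters to the failure message.
    params = ", ".join([repr(arg) for arg in args] + [f"{name}={value!r}" for name, value in kwargs.items()])

    def wrapper(msg):
        return f"{msg}\n\nFailure happened for the following parameters:\n{params}"

    return wrapper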
def _unbatch(self, batch, *, data_dims): def _unbatch(self, batch, *, data_dims):
...@@ -183,6 +188,7 @@ class TestKernels: ...@@ -183,6 +188,7 @@ class TestKernels:
actual, actual,
expected, expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device), **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, *kwargs),
) )
@sample_inputs @sample_inputs
...@@ -212,6 +218,7 @@ class TestKernels: ...@@ -212,6 +218,7 @@ class TestKernels:
output_cpu, output_cpu,
check_device=False, check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device), **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, *kwargs),
) )
@sample_inputs @sample_inputs
...@@ -237,8 +244,35 @@ class TestKernels: ...@@ -237,8 +244,35 @@ class TestKernels:
assert_close( assert_close(
actual, actual,
expected, expected,
check_dtype=False,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device), **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
)
@make_info_args_kwargs_parametrization(
[info for info in KERNEL_INFOS if info.float32_vs_uint8],
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_float32_vs_uint8(self, test_id, info, args_kwargs):
(input, *other_args), kwargs = args_kwargs.load("cpu")
if input.dtype != torch.uint8:
pytest.skip(f"Input dtype is {input.dtype}.")
adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
actual = info.kernel(
F.convert_dtype_image_tensor(input, dtype=torch.float32),
*adapted_other_args,
**adapted_kwargs,
)
expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
) )
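The new test converts the uint8 result to float32 rather than the reverse, so tolerances stay on the float32 scale: convert_dtype_image_tensor maps uint8 to float32 by dividing by 255, meaning a one-level uint8 difference becomes roughly 0.0039. A quick sanity check:

import torch
from torchvision.prototype.transforms import functional as F

image = torch.tensor([0, 1, 255], dtype=torch.uint8).reshape(1, 1, 3)
converted = F.convert_dtype_image_tensor(image, dtype=torch.float32)
assert torch.allclose(converted, torch.tensor([0.0, 1 / 255, 1.0]).reshape(1, 1, 3))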
...@@ -421,12 +455,12 @@ def test_alias(alias, target): ...@@ -421,12 +455,12 @@ def test_alias(alias, target):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("info", "args_kwargs"), ("info", "args_kwargs"),
make_info_args_kwargs_params( make_info_args_kwargs_params(
next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype), KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
args_kwargs_fn=lambda info: info.sample_inputs_fn(), args_kwargs_fn=lambda info: info.sample_inputs_fn(),
), ),
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device): def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
(input, *other_args), kwargs = args_kwargs.load(device) (input, *other_args), kwargs = args_kwargs.load(device)
dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32) dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
......