Unverified commit 35b0b9ee authored by Philip Meier, committed by GitHub

improve prototype transforms kernel tests (#6596)

* fix PIL and tensor mask comparison

* introduce kernel_name field

* add dtype consistency test

* port some tests from old framework

* add kernel infos for conversion kernels

* cleanup

* use nearest and bicubic for resize image sample inputs

* make parametrization id more obvious

* use named sentinel instead of None for random image size
parent c0911e31
@@ -14,6 +14,7 @@ from torch.nn.functional import one_hot
 from torch.testing._comparison import (
     assert_equal as _assert_equal,
     BooleanPair,
+    ErrorMeta,
     NonePair,
     NumberPair,
     TensorLikePair,
@@ -70,6 +71,19 @@ class PILImagePair(TensorLikePair):
         actual, expected = [
             to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
         ]
+        # This broadcast is needed, because `features.Mask`s can have a 2D shape, but converting the equivalent PIL
+        # image to a tensor adds a singleton leading dimension.
+        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
+        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
+        # shape check that will fail if we don't broadcast before.
+        try:
+            actual, expected = torch.broadcast_tensors(actual, expected)
+        except RuntimeError:
+            raise ErrorMeta(
+                AssertionError,
+                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
+                id=id,
+            ) from None
         return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)

     def _equalize_attributes(self, actual, expected):
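The shape mismatch this guards against is easy to reproduce with plain `torch` (a minimal sketch, not part of the patch):

```python
import torch

mask = torch.randint(0, 2, (16, 16), dtype=torch.uint8)  # a 2D `features.Mask`-like tensor
from_pil = mask.unsqueeze(0)  # PIL -> tensor conversion adds a leading dim: (1, 16, 16)

# An unconditional shape check would flag (16, 16) != (1, 16, 16);
# broadcasting first aligns both operands to the same shape.
actual, expected = torch.broadcast_tensors(mask, from_pil)
assert actual.shape == expected.shape == (1, 16, 16)
```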
@@ -165,12 +179,12 @@ class ArgsKwargs:
 DEFAULT_SQUARE_IMAGE_SIZE = 15
 DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
 DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
-DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, None)
+DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")


 def _parse_image_size(size, *, name="size"):
-    if size is None:
-        return tuple(torch.randint(16, 33, (2,)).tolist())
+    if size == "random":
+        return tuple(torch.randint(15, 33, (2,)).tolist())
     elif isinstance(size, int) and size > 0:
         return (size, size)
     elif (
@@ -181,8 +195,8 @@ def _parse_image_size(size, *, name="size"):
         return tuple(size)
     else:
         raise pytest.UsageError(
-            f"'{name}' can either be `None`, a positive integer, or a sequence of two positive integers,"
-            f"but got {size} instead"
+            f"'{name}' can either be `'random'`, a positive integer, or a sequence of two positive integers, "
+            f"but got {size} instead."
         )
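For a quick feel of the named sentinel, here is a condensed standalone rendering of the parsing logic above (the driver lines are illustrative only):

```python
import torch

def parse_image_size(size):
    # "random" is an explicit, greppable sentinel; `None` previously played this role.
    if size == "random":
        return tuple(torch.randint(15, 33, (2,)).tolist())
    elif isinstance(size, int) and size > 0:
        return (size, size)
    elif isinstance(size, (list, tuple)) and len(size) == 2:
        return tuple(size)
    raise ValueError(f"unsupported size: {size}")

print(parse_image_size("random"))  # e.g. (17, 28); both sides drawn from [15, 33)
print(parse_image_size(15))        # (15, 15)
print(parse_image_size((7, 33)))   # (7, 33)
```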
@@ -228,7 +242,7 @@ class ImageLoader(TensorLoader):
 def make_image_loader(
-    size=None,
+    size="random",
     *,
     color_space=features.ColorSpace.RGB,
     extra_dims=(),
@@ -298,7 +312,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     ).reshape(low.shape)


-def make_bounding_box_loader(*, extra_dims=(), format, image_size=None, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
     if isinstance(format, str):
         format = features.BoundingBoxFormat[format]
     if format not in {
@@ -355,7 +369,7 @@ def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(features.BoundingBoxFormat),
-    image_size=None,
+    image_size="random",
     dtypes=(torch.float32, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
@@ -440,10 +454,10 @@ class MaskLoader(TensorLoader):
     pass


-def make_detection_mask_loader(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
+def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
     size = _parse_image_size(size)
-    num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
+    num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects

     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
@@ -457,7 +471,7 @@ make_detection_mask = from_loader(make_detection_mask_loader)
 def make_detection_mask_loaders(
     sizes=DEFAULT_IMAGE_SIZES,
-    num_objects=(1, 0, None),
+    num_objects=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):
@@ -468,10 +482,10 @@ def make_detection_mask_loaders(
 make_detection_masks = from_loaders(make_detection_mask_loaders)


-def make_segmentation_mask_loader(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
+def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
     size = _parse_image_size(size)
-    num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
+    num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories

     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
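For orientation, the two mask flavors these loaders produce differ only in layout; a tiny standalone illustration (not part of the patch):

```python
import torch

num_objects, height, width = 3, 16, 16

# "Detection" masks: one binary channel per object -> (N, H, W)
detection = torch.randint(0, 2, (num_objects, height, width), dtype=torch.uint8)

# "Segmentation" masks: a single channel whose values are category ids -> (H, W)
num_categories = 5
segmentation = torch.randint(0, num_categories, (height, width), dtype=torch.uint8)

print(detection.shape, segmentation.shape)  # torch.Size([3, 16, 16]) torch.Size([16, 16])
```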
@@ -486,7 +500,7 @@ make_segmentation_mask = from_loader(make_segmentation_mask_loader)
 def make_segmentation_mask_loaders(
     *,
     sizes=DEFAULT_IMAGE_SIZES,
-    num_categories=(1, 2, None),
+    num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):
@@ -500,8 +514,8 @@ make_segmentation_masks = from_loaders(make_segmentation_mask_loaders)
 def make_mask_loaders(
     *,
     sizes=DEFAULT_IMAGE_SIZES,
-    num_objects=(1, 0, None),
-    num_categories=(1, 2, None),
+    num_objects=(1, 0, "random"),
+    num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):
...
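`combinations_grid`, used by the loaders above, comes from `datasets_utils`; as used here it behaves like a dict-valued product. A minimal reimplementation of the assumed behavior, sketched for readability:

```python
import itertools

def combinations_grid(**kwargs):
    # {"format": ("XYXY", "XYWH"), "dtype": ("float32",)} ->
    # [{"format": "XYXY", "dtype": "float32"}, {"format": "XYWH", "dtype": "float32"}]
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

print(combinations_grid(format=("XYXY", "XYWH"), dtype=("float32",)))
```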
@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Iterable, Optional
 import numpy as np
 import pytest
 import torch.testing
+import torchvision.ops
 import torchvision.prototype.transforms.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
@@ -22,6 +23,9 @@ class KernelInfo:
     # Most common tests use these inputs to check the kernel. As such, it should cover all valid code paths, but
     # should not include extensive parameter combinations to keep the overall test count moderate.
     sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
+    # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name.
+    # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then.
+    kernel_name: Optional[str] = None
     # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
     # tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
     # inside the function. It should return a tensor or, more precisely, an object that can be compared to a
@@ -34,6 +38,7 @@ class KernelInfo:
     closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

     def __post_init__(self):
+        self.kernel_name = self.kernel_name or self.kernel.__name__
         self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
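The `kernel_name` fallback uses the standard dataclass `__post_init__` hook; the same pattern in isolation:

```python
import dataclasses
from typing import Callable, Optional

@dataclasses.dataclass
class Info:
    kernel: Callable
    kernel_name: Optional[str] = None  # filled in by __post_init__ if omitted

    def __post_init__(self):
        # Fall back to the function's own name unless an alias was given explicitly.
        self.kernel_name = self.kernel_name or self.kernel.__name__

def resize_image_tensor(): ...

assert Info(resize_image_tensor).kernel_name == "resize_image_tensor"
assert Info(resize_image_tensor, kernel_name="resize").kernel_name == "resize"
```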
@@ -62,7 +67,7 @@ KERNEL_INFOS = []
 def sample_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(dtypes=[torch.float32]):
+    for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
         yield ArgsKwargs(image_loader)
@@ -72,14 +77,16 @@ def reference_inputs_horizontal_flip_image_tensor():
 def sample_inputs_horizontal_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
+    ):
         yield ArgsKwargs(
             bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size
         )


 def sample_inputs_horizontal_flip_mask():
-    for image_loader in make_mask_loaders(dtypes=[torch.uint8]):
+    for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
         yield ArgsKwargs(image_loader)
@@ -87,6 +94,7 @@ KERNEL_INFOS.extend(
     [
         KernelInfo(
             F.horizontal_flip_image_tensor,
+            kernel_name="horizontal_flip_image_tensor",
             sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
             reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
             reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
@@ -104,23 +112,34 @@ KERNEL_INFOS.extend(
 )


+def _get_resize_sizes(image_size):
+    height, width = image_size
+    yield height, width
+    yield int(height * 0.75), int(width * 1.25)
+
+
 def sample_inputs_resize_image_tensor():
     for image_loader, interpolation in itertools.product(
         make_image_loaders(dtypes=[torch.float32]),
         [
             F.InterpolationMode.NEAREST,
-            F.InterpolationMode.BILINEAR,
             F.InterpolationMode.BICUBIC,
         ],
     ):
-        height, width = image_loader.image_size
-        for size in [
-            (height, width),
-            (int(height * 0.75), int(width * 1.25)),
-        ]:
+        for size in _get_resize_sizes(image_loader.image_size):
             yield ArgsKwargs(image_loader, size=size, interpolation=interpolation)


+@pil_reference_wrapper
+def reference_resize_image_tensor(*args, **kwargs):
+    if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
+        F.InterpolationMode.BILINEAR,
+        F.InterpolationMode.BICUBIC,
+    }:
+        raise pytest.UsageError("Anti-aliasing is always active in PIL")
+
+    return F.resize_image_pil(*args, **kwargs)
+
+
 def reference_inputs_resize_image_tensor():
     for image_loader, interpolation in itertools.product(
         make_image_loaders(extra_dims=[()]),
@@ -130,30 +149,48 @@ def reference_inputs_resize_image_tensor():
             F.InterpolationMode.BICUBIC,
         ],
     ):
-        height, width = image_loader.image_size
-        for size in [
-            (height, width),
-            (int(height * 0.75), int(width * 1.25)),
-        ]:
-            yield ArgsKwargs(image_loader, size=size, interpolation=interpolation)
+        for size in _get_resize_sizes(image_loader.image_size):
+            yield ArgsKwargs(
+                image_loader,
+                size=size,
+                interpolation=interpolation,
+                antialias=interpolation
+                in {
+                    F.InterpolationMode.BILINEAR,
+                    F.InterpolationMode.BICUBIC,
+                },
+            )


 def sample_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
-        height, width = bounding_box_loader.image_size
-        for size in [
-            (height, width),
-            (int(height * 0.75), int(width * 1.25)),
-        ]:
+    for bounding_box_loader in make_bounding_box_loaders(formats=[features.BoundingBoxFormat.XYXY]):
+        for size in _get_resize_sizes(bounding_box_loader.image_size):
             yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.image_size)


+def sample_inputs_resize_mask():
+    for mask_loader in make_mask_loaders(dtypes=[torch.uint8]):
+        for size in _get_resize_sizes(mask_loader.shape[-2:]):
+            yield ArgsKwargs(mask_loader, size=size)
+
+
+@pil_reference_wrapper
+def reference_resize_mask(*args, **kwargs):
+    return F.resize_image_pil(*args, interpolation=F.InterpolationMode.NEAREST, **kwargs)
+
+
+def reference_inputs_resize_mask():
+    for mask_loader in make_mask_loaders(extra_dims=[()], num_objects=[1]):
+        for size in _get_resize_sizes(mask_loader.shape[-2:]):
+            yield ArgsKwargs(mask_loader, size=size)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
             F.resize_image_tensor,
             sample_inputs_fn=sample_inputs_resize_image_tensor,
-            reference_fn=pil_reference_wrapper(F.resize_image_pil),
+            reference_fn=reference_resize_image_tensor,
             reference_inputs_fn=reference_inputs_resize_image_tensor,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
         ),
@@ -161,6 +198,13 @@ KERNEL_INFOS.extend(
             F.resize_bounding_box,
             sample_inputs_fn=sample_inputs_resize_bounding_box,
         ),
+        KernelInfo(
+            F.resize_mask,
+            sample_inputs_fn=sample_inputs_resize_mask,
+            reference_fn=reference_resize_mask,
+            reference_inputs_fn=reference_inputs_resize_mask,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
     ]
 )
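The `antialias` gating above reflects a PIL quirk: PIL resizing always anti-aliases, so only the `antialias=True` tensor path is meaningfully comparable against a PIL reference for bilinear/bicubic. A quick check using the stable torchvision API (an assumption for illustration; this API is not part of the diff):

```python
import torch
from torchvision.transforms import functional as TF, InterpolationMode

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
pil_result = TF.pil_to_tensor(
    TF.resize(TF.to_pil_image(img), [16, 16], interpolation=InterpolationMode.BILINEAR)
)
for antialias in (True, False):
    tensor_result = TF.resize(img, [16, 16], interpolation=InterpolationMode.BILINEAR, antialias=antialias)
    diff = (tensor_result.int() - pil_result.int()).abs().max().item()
    print(f"antialias={antialias}: max abs difference vs. PIL = {diff}")
# Expectation: a small difference for antialias=True, a much larger one for antialias=False.
```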
@@ -193,9 +237,9 @@ def sample_inputs_affine_image_tensor():
 def reference_inputs_affine_image_tensor():
-    for image, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
+    for image_loader, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
         yield ArgsKwargs(
-            image,
+            image_loader,
             interpolation=F.InterpolationMode.NEAREST,
             **affine_kwargs,
         )
@@ -234,7 +278,7 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix


-def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center):
+def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center=None):
     if center is None:
         center = [s * 0.5 for s in image_size[::-1]]
@@ -259,7 +303,8 @@ def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center):
             np.max(transformed_points[:, 0]),
             np.max(transformed_points[:, 1]),
         ],
-        dtype=bbox.dtype,
+        # FIXME: re-add this as soon as the kernel is fixed to also retain the dtype
+        # dtype=bbox.dtype,
     )
     return F.convert_format_bounding_box(
         out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
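Why the FIXME: without an explicit `dtype=`, `np.array` infers `float64` for the transformed corner coordinates, so the reference's dtype can drift from the input's while the kernel is eventually expected to retain it. The inference in isolation (illustrative values):

```python
import numpy as np
import torch

corners = np.array([10.2, 19.7, 30.1, 40.3])  # no explicit dtype given
print(corners.dtype)                   # float64, regardless of the input bbox dtype
print(torch.as_tensor(corners).dtype)  # torch.float64 -> may not match e.g. torch.int64 boxes
```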
@@ -278,26 +323,38 @@ def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center):
 def reference_inputs_affine_bounding_box():
-    for bounding_box_loader, angle, translate, scale, shear, center in itertools.product(
-        make_bounding_box_loaders(extra_dims=[(4,)], image_size=(32, 38), dtypes=[torch.float32]),
-        range(-90, 90, 56),
-        range(-10, 10, 8),
-        [0.77, 1.0, 1.27],
-        range(-15, 15, 8),
-        [None, (12, 14)],
+    for bounding_box_loader, affine_kwargs in itertools.product(
+        make_bounding_box_loaders(extra_dims=[()]),
+        _AFFINE_KWARGS,
     ):
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
             image_size=bounding_box_loader.image_size,
-            angle=angle,
-            translate=(translate, translate),
-            scale=scale,
-            shear=(shear, shear),
-            center=center,
+            **affine_kwargs,
         )


+def sample_inputs_affine_image_mask():
+    for mask_loader, center in itertools.product(
+        make_mask_loaders(dtypes=[torch.uint8]),
+        [None, (0, 0)],
+    ):
+        yield ArgsKwargs(mask_loader, center=center, **_AFFINE_KWARGS[0])
+
+
+@pil_reference_wrapper
+def reference_affine_mask(*args, **kwargs):
+    return F.affine_image_pil(*args, interpolation=F.InterpolationMode.NEAREST, **kwargs)
+
+
+def reference_inputs_affine_mask():
+    for mask_loader, affine_kwargs in itertools.product(
+        make_mask_loaders(extra_dims=[()], num_objects=[1]), _AFFINE_KWARGS
+    ):
+        yield ArgsKwargs(mask_loader, **affine_kwargs)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -313,5 +370,84 @@ KERNEL_INFOS.extend(
             reference_fn=reference_affine_bounding_box,
             reference_inputs_fn=reference_inputs_affine_bounding_box,
         ),
+        KernelInfo(
+            F.affine_mask,
+            sample_inputs_fn=sample_inputs_affine_image_mask,
+            reference_fn=reference_affine_mask,
+            reference_inputs_fn=reference_inputs_affine_mask,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
     ]
 )
+
+
+def sample_inputs_convert_format_bounding_box():
+    formats = set(features.BoundingBoxFormat)
+    for bounding_box_loader in make_bounding_box_loaders(formats=formats):
+        old_format = bounding_box_loader.format
+        for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)):
+            yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params)
+
+
+def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy):
+    if not copy:
+        raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`")
+
+    return torchvision.ops.box_convert(
+        bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
+    )
+
+
+def reference_inputs_convert_format_bounding_box():
+    for args_kwargs in sample_inputs_convert_format_bounding_box():
+        (bounding_box_loader, *other_args), kwargs = args_kwargs
+        if len(bounding_box_loader.shape) == 2 and kwargs.setdefault("copy", True):
+            yield args_kwargs
+
+
+KERNEL_INFOS.append(
+    KernelInfo(
+        F.convert_format_bounding_box,
+        sample_inputs_fn=sample_inputs_convert_format_bounding_box,
+        reference_fn=reference_convert_format_bounding_box,
+        reference_inputs_fn=reference_inputs_convert_format_bounding_box,
+    ),
+)
+
+
+def sample_inputs_convert_color_space_image_tensor():
+    color_spaces = set(features.ColorSpace) - {features.ColorSpace.OTHER}
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=color_spaces, constant_alpha=True):
+        old_color_space = image_loader.color_space
+        for params in combinations_grid(new_color_space=color_spaces - {old_color_space}, copy=(True, False)):
+            yield ArgsKwargs(image_loader, old_color_space=old_color_space, **params)
+
+
+@pil_reference_wrapper
+def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy):
+    color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode)
+    if color_space_pil != old_color_space:
+        raise pytest.UsageError(
+            f"Converting the tensor image into a PIL image changed the color space "
+            f"from {old_color_space} to {color_space_pil}"
+        )
+
+    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy)
+
+
+def reference_inputs_convert_color_space_image_tensor():
+    for args_kwargs in sample_inputs_convert_color_space_image_tensor():
+        (image_loader, *other_args), kwargs = args_kwargs
+        if len(image_loader.shape) == 3:
+            yield args_kwargs
+
+
+KERNEL_INFOS.append(
+    KernelInfo(
+        F.convert_color_space_image_tensor,
+        sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
+        reference_fn=reference_convert_color_space_image_tensor,
+        reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
+        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+    ),
+)
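For reference, `torchvision.ops.box_convert`, which the new bounding-box reference relies on, works on plain tensors:

```python
import torch
from torchvision.ops import box_convert

boxes_xyxy = torch.tensor([[10.0, 20.0, 30.0, 60.0]])
print(box_convert(boxes_xyxy, in_fmt="xyxy", out_fmt="xywh"))
# tensor([[10., 20., 20., 40.]]) -- x, y, width, height
print(box_convert(boxes_xyxy, in_fmt="xyxy", out_fmt="cxcywh"))
# tensor([[20., 40., 20., 40.]]) -- center x, center y, width, height
```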
@@ -66,40 +66,6 @@ def vertical_flip_mask():
         yield ArgsKwargs(mask)


-@register_kernel_info_from_sample_inputs_fn
-def resize_mask():
-    for mask, max_size in itertools.product(
-        make_masks(),
-        [None, 34],  # max_size
-    ):
-        height, width = mask.shape[-2:]
-        for size in [
-            (height, width),
-            (int(height * 0.75), int(width * 1.25)),
-        ]:
-            if max_size is not None:
-                size = [size[0]]
-            yield ArgsKwargs(mask, size=size, max_size=max_size)
-
-
-@register_kernel_info_from_sample_inputs_fn
-def affine_mask():
-    for mask, angle, translate, scale, shear in itertools.product(
-        make_masks(),
-        [-87, 15, 90],  # angle
-        [5, -5],  # translate
-        [0.77, 1.27],  # scale
-        [0, 12],  # shear
-    ):
-        yield ArgsKwargs(
-            mask,
-            angle=angle,
-            translate=(translate, translate),
-            scale=scale,
-            shear=(shear, shear),
-        )
-
-
 @register_kernel_info_from_sample_inputs_fn
 def rotate_image_tensor():
     for image, angle, expand, center in itertools.product(
@@ -507,6 +473,41 @@ def test_eager_vs_scripted(functional_info, sample_input):
     torch.testing.assert_close(eager, scripted)


+@pytest.mark.parametrize(
+    ("functional_info", "sample_input"),
+    [
+        pytest.param(
+            functional_info,
+            sample_input,
+            id=f"{functional_info.name}-{idx}",
+            marks=[
+                *(
+                    [pytest.mark.xfail(strict=False)]
+                    if functional_info.name
+                    in {
+                        "rotate_bounding_box",
+                        "crop_bounding_box",
+                        "resized_crop_bounding_box",
+                        "perspective_bounding_box",
+                        "elastic_bounding_box",
+                        "center_crop_bounding_box",
+                    }
+                    else []
+                )
+            ],
+        )
+        for functional_info in FUNCTIONAL_INFOS
+        for idx, sample_input in enumerate(functional_info.sample_inputs())
+    ],
+)
+def test_dtype_consistency(functional_info, sample_input):
+    (input, *other_args), kwargs = sample_input
+
+    output = functional_info.functional(input, *other_args, **kwargs)
+
+    assert output.dtype == input.dtype
+
+
 def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
     rot = math.radians(angle_)
     cx, cy = center_
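The `marks=` plumbing in the new parametrization above is the standard way to xfail individual parameter sets without skipping the rest; a minimal self-contained analogue:

```python
import pytest

KNOWN_BROKEN = {"kernel_b"}  # hypothetical names, for illustration only

@pytest.mark.parametrize(
    "name",
    [
        pytest.param(
            name,
            marks=[pytest.mark.xfail(strict=False)] if name in KNOWN_BROKEN else [],
        )
        for name in ["kernel_a", "kernel_b"]
    ],
)
def test_something(name):
    assert name != "kernel_b"  # the broken case is reported as xfail, not as a failure
```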
@@ -530,82 +531,6 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
     return true_matrix


-@pytest.mark.parametrize("angle", range(-90, 90, 56))
-@pytest.mark.parametrize("translate", range(-10, 10, 8))
-@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
-@pytest.mark.parametrize("shear", range(-15, 15, 8))
-@pytest.mark.parametrize("center", [None, (12, 14)])
-def test_correctness_affine_bounding_box(angle, translate, scale, shear, center):
-    def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
-        affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
-        affine_matrix = affine_matrix[:2, :]
-
-        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
-        )
-        points = np.array(
-            [
-                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
-            ]
-        )
-        transformed_points = np.matmul(points, affine_matrix.T)
-        out_bbox = [
-            np.min(transformed_points[:, 0]),
-            np.min(transformed_points[:, 1]),
-            np.max(transformed_points[:, 0]),
-            np.max(transformed_points[:, 1]),
-        ]
-        out_bbox = features.BoundingBox(
-            out_bbox,
-            format=features.BoundingBoxFormat.XYXY,
-            image_size=bbox.image_size,
-            dtype=torch.float32,
-            device=bbox.device,
-        )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-        )
-
-    image_size = (32, 38)
-
-    for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)):
-        bboxes_format = bboxes.format
-        bboxes_image_size = bboxes.image_size
-
-        output_bboxes = F.affine_bounding_box(
-            bboxes,
-            bboxes_format,
-            image_size=bboxes_image_size,
-            angle=angle,
-            translate=(translate, translate),
-            scale=scale,
-            shear=(shear, shear),
-            center=center,
-        )
-
-        center_ = center
-        if center_ is None:
-            center_ = [s * 0.5 for s in bboxes_image_size[::-1]]
-
-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
-            expected_bboxes.append(
-                _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_)
-            )
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
-
-        torch.testing.assert_close(output_bboxes, expected_bboxes)
-
-
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_bounding_box_on_fixed_input(device):
     # Check transformation against known expected output
@@ -655,60 +580,6 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


-@pytest.mark.parametrize("angle", [-54, 56])
-@pytest.mark.parametrize("translate", [-7, 8])
-@pytest.mark.parametrize("scale", [0.89, 1.12])
-@pytest.mark.parametrize("shear", [4])
-@pytest.mark.parametrize("center", [None, (12, 14)])
-def test_correctness_affine_mask(angle, translate, scale, shear, center):
-    def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
-        assert mask.ndim == 3
-        affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
-        inv_affine_matrix = np.linalg.inv(affine_matrix)
-        inv_affine_matrix = inv_affine_matrix[:2, :]
-
-        expected_mask = torch.zeros_like(mask.cpu())
-        for out_y in range(expected_mask.shape[1]):
-            for out_x in range(expected_mask.shape[2]):
-                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
-                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype("int")
-                in_x, in_y = input_pt[:2]
-                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
-                    for i in range(expected_mask.shape[0]):
-                        expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
-        return expected_mask.to(mask.device)
-
-    # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
-    for mask in make_detection_masks(extra_dims=((), (4,))):
-        output_mask = F.affine_mask(
-            mask,
-            angle=angle,
-            translate=(translate, translate),
-            scale=scale,
-            shear=(shear, shear),
-            center=center,
-        )
-
-        center_ = center
-        if center_ is None:
-            center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
-
-        if mask.ndim < 4:
-            masks = [mask]
-        else:
-            masks = [m for m in mask]
-
-        expected_masks = []
-        for mask in masks:
-            expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_)
-            expected_masks.append(expected_mask)
-        if len(expected_masks) > 1:
-            expected_masks = torch.stack(expected_masks)
-        else:
-            expected_masks = expected_masks[0]
-
-        torch.testing.assert_close(output_mask, expected_masks)
-
-
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_segmentation_mask_on_fixed_input(device):
     # Check transformation against known expected output and CPU/CUDA devices
...
@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F
 def test_coverage():
-    tested = {info.kernel.__name__ for info in KERNEL_INFOS}
+    tested = {info.kernel_name for info in KERNEL_INFOS}
     exposed = {
         name
         for name, kernel in F.__dict__.items()
@@ -36,14 +36,11 @@ def test_coverage():
             "adjust_hue_image_tensor",
             "adjust_saturation_image_tensor",
             "adjust_sharpness_image_tensor",
-            "affine_mask",
             "autocontrast_image_tensor",
             "center_crop_bounding_box",
             "center_crop_image_tensor",
             "center_crop_mask",
             "clamp_bounding_box",
-            "convert_color_space_image_tensor",
-            "convert_format_bounding_box",
             "crop_bounding_box",
             "crop_image_tensor",
             "crop_mask",
@@ -54,7 +51,6 @@ def test_coverage():
             "erase_image_tensor",
             "five_crop_image_tensor",
             "gaussian_blur_image_tensor",
-            "horizontal_flip_image_tensor",
             "invert_image_tensor",
             "normalize_image_tensor",
             "pad_bounding_box",
@@ -64,7 +60,6 @@ def test_coverage():
             "perspective_image_tensor",
             "perspective_mask",
             "posterize_image_tensor",
-            "resize_mask",
             "resized_crop_bounding_box",
             "resized_crop_image_tensor",
             "resized_crop_mask",
@@ -79,6 +74,13 @@ def test_coverage():
         }
     }

+    needlessly_ignored = tested - exposed
+    if needlessly_ignored:
+        raise pytest.UsageError(
+            f"The kernel(s) {sequence_to_str(sorted(needlessly_ignored), separate_last='and ')} "
+            f"have an associated `KernelInfo` but are ignored by this test."
+        )
+
     untested = exposed - tested
     if untested:
         raise AssertionError(
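Both coverage failure modes reduce to set differences over kernel names; a stripped-down model of the logic (hypothetical names, for illustration only):

```python
exposed = {"resize_image_tensor", "resize_mask"}      # public kernels, minus the ignore list
tested = {"resize_image_tensor", "old_alias_kernel"}  # names that have a KernelInfo

needlessly_ignored = tested - exposed  # KernelInfos whose kernel is ignored (or no longer exposed)
untested = exposed - tested            # public kernels without a KernelInfo
assert needlessly_ignored == {"old_alias_kernel"}
assert untested == {"resize_mask"}
```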
@@ -92,9 +94,9 @@ class TestCommon:
     sample_inputs = pytest.mark.parametrize(
         ("info", "args_kwargs"),
         [
-            pytest.param(info, args_kwargs, id=f"{info.kernel.__name__}")
+            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
             for info in KERNEL_INFOS
-            for args_kwargs in info.sample_inputs_fn()
+            for idx, args_kwargs in enumerate(info.sample_inputs_fn())
         ],
     )
@@ -182,14 +184,47 @@ class TestCommon:
         output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
         output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

-        assert_close(output_cuda, output_cpu, check_device=False)
+        assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs)
+
+    @pytest.mark.parametrize(
+        ("info", "args_kwargs"),
+        [
+            pytest.param(
+                info,
+                args_kwargs,
+                id=f"{info.kernel_name}-{idx}",
+                marks=[
+                    *(
+                        [pytest.mark.xfail(strict=False)]
+                        if info.kernel_name
+                        in {
+                            "resize_bounding_box",
+                            "affine_bounding_box",
+                            "convert_format_bounding_box",
+                        }
+                        else []
+                    )
+                ],
+            )
+            for info in KERNEL_INFOS
+            for idx, args_kwargs in enumerate(info.sample_inputs_fn())
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
+        (input, *other_args), kwargs = args_kwargs.load(device)
+
+        output = info.kernel(input, *other_args, **kwargs)
+
+        assert output.dtype == input.dtype
+        assert output.device == torch.device(device)
     @pytest.mark.parametrize(
         ("info", "args_kwargs"),
         [
-            pytest.param(info, args_kwargs, id=f"{info.kernel.__name__}")
+            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
             for info in KERNEL_INFOS
-            for args_kwargs in info.reference_inputs_fn()
+            for idx, args_kwargs in enumerate(info.reference_inputs_fn())
             if info.reference_fn is not None
         ],
     )
@@ -199,4 +234,4 @@ class TestCommon:
         actual = info.kernel(*args, **kwargs)
         expected = info.reference_fn(*args, **kwargs)

-        assert_close(actual, expected, **info.closeness_kwargs, check_dtype=False)
+        assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs)