Unverified Commit 29b0831c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

diversify parameter types for a couple of prototype kernels (#6635)

* add more size types for prototype resize sample inputs

* add skip for dispatcher

* add more sizes to resize kernel info

* add more skips

* add more diversity to gaussian_blur parameters

* diversify affine parameters and fix bounding box kernel

* fix center_crop dispatcher info

* revert kernel fixes

* add skips for scalar shears in affine_bounding_box
parent 2d927283
import dataclasses
from typing import Callable, Dict, Sequence, Type
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Type
import pytest
import torchvision.prototype.transforms.functional as F
......@@ -11,15 +12,30 @@ __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
return Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
reason=reason,
)
def skip_integer_size_jit(name="size"):
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self._skips_map = {skip.test_name: skip for skip in self.skips}
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
def sample_inputs(self, *types):
for type in types or self.kernels.keys():
......@@ -29,9 +45,13 @@ class DispatcherInfo:
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
skips = self._skips_map.get(test_name)
if not skips:
return
for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
DISPATCHER_INFOS = [
......@@ -50,6 +70,9 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.resize_bounding_box,
features.Mask: F.resize_mask,
},
skips=[
skip_integer_size_jit(),
],
),
DispatcherInfo(
F.affine,
......@@ -58,6 +81,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.affine_bounding_box,
features.Mask: F.affine_mask,
},
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
),
DispatcherInfo(
F.vertical_flip,
......@@ -122,12 +146,19 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.center_crop_bounding_box,
features.Mask: F.center_crop_mask,
},
skips=[
skip_integer_size_jit("output_size"),
],
),
DispatcherInfo(
F.gaussian_blur,
kernels={
features.Image: F.gaussian_blur_image_tensor,
},
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
],
),
DispatcherInfo(
F.equalize,
......@@ -207,11 +238,7 @@ DISPATCHER_INFOS = [
features.Image: F.five_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
skip_integer_size_jit(),
],
),
DispatcherInfo(
......@@ -220,11 +247,7 @@ DISPATCHER_INFOS = [
features.Image: F.ten_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
skip_integer_size_jit(),
],
),
DispatcherInfo(
......
......@@ -2,7 +2,8 @@ import dataclasses
import functools
import itertools
import math
from typing import Any, Callable, Dict, Iterable, Optional, Sequence
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
import numpy as np
import pytest
......@@ -44,17 +45,25 @@ class KernelInfo:
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
self._skips_map = {skip.test_name: skip for skip in self.skips}
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
skips = self._skips_map.get(test_name)
if not skips:
return
for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
......@@ -78,6 +87,18 @@ def pil_reference_wrapper(pil_kernel):
return wrapper
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
return Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
reason=reason,
)
def skip_integer_size_jit(name="size"):
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
KERNEL_INFOS = []
......@@ -129,8 +150,15 @@ KERNEL_INFOS.extend(
def _get_resize_sizes(image_size):
height, width = image_size
length = max(image_size)
# FIXME: enable me when the kernels are fixed
# yield length
yield [length]
yield (length,)
new_height = int(height * 0.75)
new_width = int(width * 1.25)
yield [new_height, new_width]
yield height, width
yield int(height * 0.75), int(width * 1.25)
def sample_inputs_resize_image_tensor():
......@@ -208,10 +236,16 @@ KERNEL_INFOS.extend(
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit(),
],
),
KernelInfo(
F.resize_bounding_box,
sample_inputs_fn=sample_inputs_resize_bounding_box,
skips=[
skip_integer_size_jit(),
],
),
KernelInfo(
F.resize_mask,
......@@ -219,6 +253,9 @@ KERNEL_INFOS.extend(
reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit(),
],
),
]
)
......@@ -232,6 +269,16 @@ _AFFINE_KWARGS = combinations_grid(
)
def _diversify_affine_kwargs_types(affine_kwargs):
angle = affine_kwargs["angle"]
for diverse_angle in [int(angle), float(angle)]:
yield dict(affine_kwargs, angle=diverse_angle)
shear = affine_kwargs["shear"]
for diverse_shear in [tuple(shear), list(shear), int(shear[0]), float(shear[0])]:
yield dict(affine_kwargs, shear=diverse_shear)
def sample_inputs_affine_image_tensor():
for image_loader, interpolation_mode, center in itertools.product(
make_image_loaders(sizes=["random"], dtypes=[torch.float32]),
......@@ -250,6 +297,11 @@ def sample_inputs_affine_image_tensor():
**_AFFINE_KWARGS[0],
)
for image_loader, affine_kwargs in itertools.product(
make_image_loaders(sizes=["random"], dtypes=[torch.float32]), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0])
):
yield ArgsKwargs(image_loader, **affine_kwargs)
def reference_inputs_affine_image_tensor():
for image_loader, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS):
......@@ -269,6 +321,16 @@ def sample_inputs_affine_bounding_box():
**_AFFINE_KWARGS[0],
)
for bounding_box_loader, affine_kwargs in itertools.product(
make_bounding_box_loaders(), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0])
):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size,
**affine_kwargs,
)
def _compute_affine_matrix(angle, translate, scale, shear, center):
rot = math.radians(angle)
......@@ -356,6 +418,11 @@ def sample_inputs_affine_image_mask():
):
yield ArgsKwargs(mask_loader, center=center, **_AFFINE_KWARGS[0])
for mask_loader, affine_kwargs in itertools.product(
make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0])
):
yield ArgsKwargs(mask_loader, **affine_kwargs)
@pil_reference_wrapper
def reference_affine_mask(*args, **kwargs):
......@@ -369,6 +436,16 @@ def reference_inputs_resize_mask():
yield ArgsKwargs(mask_loader, **affine_kwargs)
# FIXME: @datumbox, remove this as soon as you have fixed the behavior in https://github.com/pytorch/vision/pull/6636
def skip_scalar_shears(*test_names):
for test_name in test_names:
yield Skip(
test_name,
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["shear"], (int, float)),
reason="The kernel is broken for a scalar `shear`",
)
KERNEL_INFOS.extend(
[
KernelInfo(
......@@ -377,6 +454,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
),
KernelInfo(
F.affine_bounding_box,
......@@ -384,6 +462,14 @@ KERNEL_INFOS.extend(
reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box,
closeness_kwargs=dict(atol=1, rtol=0),
skips=[
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
*skip_scalar_shears(
"test_batched_vs_single",
"test_no_inplace",
"test_dtype_and_device_consistency",
),
],
),
KernelInfo(
F.affine_mask,
......@@ -391,6 +477,7 @@ KERNEL_INFOS.extend(
reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
),
]
)
......@@ -955,7 +1042,7 @@ KERNEL_INFOS.extend(
_CENTER_CROP_IMAGE_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4]]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
def sample_inputs_center_crop_image_tensor():
......@@ -1004,10 +1091,16 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit("output_size"),
],
),
KernelInfo(
F.center_crop_bounding_box,
sample_inputs_fn=sample_inputs_center_crop_bounding_box,
skips=[
skip_integer_size_jit("output_size"),
],
),
KernelInfo(
F.center_crop_mask,
......@@ -1015,6 +1108,9 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit("output_size"),
],
),
]
)
......@@ -1028,8 +1124,8 @@ def sample_inputs_gaussian_blur_image_tensor():
extra_dims=[(), (4,)],
),
combinations_grid(
kernel_size=[(3, 3)],
sigma=[None, (3.0, 3.0)],
kernel_size=[(3, 3), [3, 3], 5],
sigma=[None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)],
),
):
yield ArgsKwargs(image_loader, **params)
......@@ -1040,6 +1136,10 @@ KERNEL_INFOS.append(
F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
],
)
)
......@@ -1450,11 +1550,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
skips=[
Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
......@@ -1467,11 +1563,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
skips=[
Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment