Unverified commit 46eae182, authored by Philip Meier, committed by GitHub

use pytest markers instead of custom solution for prototype transforms functional tests (#6653)

* use pytest markers instead of custom solution for prototype transforms functional tests

* cleanup

* cleanup

* trigger CI
parent a46c4f0c
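
For readers unfamiliar with the mechanism this commit adopts: pytest can attach marks to individual parameter sets at collection time via pytest.param(..., marks=...), and that is what replaces the old Skip objects and their run-time pytest.skip() calls. Below is a minimal, self-contained sketch of the idea with simplified names; it illustrates the pattern only and is not the exact torchvision code:

    import dataclasses
    from typing import Callable, Optional, Tuple

    import pytest

    TestID = Tuple[Optional[str], str]  # (test class name, test function name)


    @dataclasses.dataclass
    class TestMark:
        test_id: TestID
        mark: "pytest.MarkDecorator"
        # Apply the mark only to parameter sets that match this predicate.
        condition: Callable = lambda kwargs: True


    # A condition-gated xfail: scripting fails for Python scalar `size` arguments.
    scalar_size_xfail = TestMark(
        (None, "test_scripted"),
        pytest.mark.xfail(reason="Python scalar size is not supported when scripting"),
        condition=lambda kwargs: isinstance(kwargs["size"], (int, float)),
    )


    def params_for(test_id, sample_kwargs):
        # At collection time, attach every matching mark to the parameter set.
        return [
            pytest.param(
                kwargs,
                marks=[scalar_size_xfail.mark]
                if scalar_size_xfail.test_id == test_id and scalar_size_xfail.condition(kwargs)
                else [],
                id=f"sample-{idx}",
            )
            for idx, kwargs in enumerate(sample_kwargs)
        ]


    @pytest.mark.parametrize("kwargs", params_for((None, "test_scripted"), [{"size": 2}, {"size": [2]}]))
    def test_scripted(kwargs):
        # The scalar case is collected as xfail; the list case must pass.
        assert isinstance(kwargs["size"], list)
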
@@ -2,12 +2,12 @@ import collections.abc
 import dataclasses
 from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Sequence, Type

 import pytest

 import torchvision.prototype.transforms.functional as F
-from prototype_common_utils import BoundingBoxLoader
-from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
+from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
 from torchvision.prototype import features

 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -24,35 +24,27 @@ class PILKernelInfo:
         self.kernel_name = self.kernel_name or self.kernel.__name__


-def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
-    return Skip(
-        "test_scripted_smoke",
-        condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
-        reason=reason,
-    )
-
-
-def skip_integer_size_jit(name="size"):
-    return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
-
-
 @dataclasses.dataclass
 class DispatcherInfo:
     dispatcher: Callable
     kernels: Dict[Type, Callable]
-    kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
     pil_kernel_info: Optional[PILKernelInfo] = None
     method_name: str = dataclasses.field(default=None)
-    skips: Sequence[Skip] = dataclasses.field(default_factory=list)
-    _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
+    test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
+    _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

     def __post_init__(self):
         self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
         self.method_name = self.method_name or self.dispatcher.__name__
-        skips_map = defaultdict(list)
-        for skip in self.skips:
-            skips_map[skip.test_name].append(skip)
-        self._skips_map = dict(skips_map)
+        test_marks_map = defaultdict(list)
+        for test_mark in self.test_marks:
+            test_marks_map[test_mark.test_id].append(test_mark)
+        self._test_marks_map = dict(test_marks_map)
+
+    def get_marks(self, test_id, args_kwargs):
+        return [
+            test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
+        ]

     def sample_inputs(self, *feature_types, filter_metadata=True):
         for feature_type in feature_types or self.kernels.keys():
@@ -70,17 +62,27 @@ class DispatcherInfo:
             yield args_kwargs

-    def maybe_skip(self, *, test_name, args_kwargs, device):
-        skips = self._skips_map.get(test_name)
-        if not skips:
-            return
-
-        for skip in skips:
-            if skip.condition(args_kwargs, device):
-                pytest.skip(skip.reason)
+
+def xfail_python_scalar_arg_jit(name, *, reason=None):
+    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
+    return TestMark(
+        ("TestDispatchers", "test_scripted_smoke"),
+        pytest.mark.xfail(reason=reason),
+        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
+    )


-def fill_sequence_needs_broadcast(args_kwargs, device):
+def xfail_integer_size_jit(name="size"):
+    return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")
+
+
+skip_dispatch_feature = TestMark(
+    ("TestDispatchers", "test_dispatch_feature"),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
+)
+
+
+def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
     try:
         fill = kwargs["fill"]
@@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
     return image_loader.num_channels > 1


-skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
-    "test_dispatch_pil",
+xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
+    ("TestDispatchers", "test_dispatch_pil"),
+    pytest.mark.xfail(
+        reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
+    ),
     condition=fill_sequence_needs_broadcast,
-    reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
-)
-
-
-skip_dispatch_feature = Skip(
-    "test_dispatch_feature",
-    reason="Dispatcher doesn't support arbitrary feature dispatch.",
 )
@@ -123,8 +122,8 @@ DISPATCHER_INFOS = [
             features.Mask: F.resize_mask,
         },
         pil_kernel_info=PILKernelInfo(F.resize_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
         ],
     ),
     DispatcherInfo(
@@ -135,9 +134,9 @@ DISPATCHER_INFOS = [
             features.Mask: F.affine_mask,
         },
         pil_kernel_info=PILKernelInfo(F.affine_image_pil),
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-            skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
+        test_marks=[
+            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            xfail_python_scalar_arg_jit("shear"),
         ],
     ),
     DispatcherInfo(
@@ -166,16 +165,6 @@ DISPATCHER_INFOS = [
             features.Mask: F.crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
-        skips=[
-            Skip(
-                "test_dispatch_feature",
-                condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
-                reason=(
-                    "F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
-                    "since that is sufficient for the kernel."
-                ),
-            )
-        ],
     ),
     DispatcherInfo(
         F.resized_crop,
@@ -193,10 +182,20 @@ DISPATCHER_INFOS = [
             features.BoundingBox: F.pad_bounding_box,
             features.Mask: F.pad_mask,
         },
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-        ],
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
+        test_marks=[
+            TestMark(
+                ("TestDispatchers", "test_dispatch_pil"),
+                pytest.mark.xfail(
+                    reason=(
+                        "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
+                        "`padding_mode='constant'`, if the number of color channels is larger."
+                    )
+                ),
+                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
+                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
+            )
+        ],
     ),
     DispatcherInfo(
         F.perspective,
@@ -205,10 +204,10 @@ DISPATCHER_INFOS = [
             features.BoundingBox: F.perspective_bounding_box,
             features.Mask: F.perspective_mask,
         },
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-        ],
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
+        test_marks=[
+            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+        ],
     ),
     DispatcherInfo(
         F.elastic,
@@ -227,8 +226,8 @@ DISPATCHER_INFOS = [
             features.Mask: F.center_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
-        skips=[
-            skip_integer_size_jit("output_size"),
+        test_marks=[
+            xfail_integer_size_jit("output_size"),
         ],
     ),
     DispatcherInfo(
@@ -237,9 +236,9 @@ DISPATCHER_INFOS = [
             features.Image: F.gaussian_blur_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
-        skips=[
-            skip_python_scalar_arg_jit("kernel_size"),
-            skip_python_scalar_arg_jit("sigma"),
+        test_marks=[
+            xfail_python_scalar_arg_jit("kernel_size"),
+            xfail_python_scalar_arg_jit("sigma"),
         ],
     ),
     DispatcherInfo(
@@ -290,7 +289,7 @@ DISPATCHER_INFOS = [
             features.Image: F.erase_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.erase_image_pil),
-        skips=[
+        test_marks=[
             skip_dispatch_feature,
         ],
     ),
@@ -335,8 +334,8 @@ DISPATCHER_INFOS = [
             features.Image: F.five_crop_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
             skip_dispatch_feature,
         ],
     ),
@@ -345,18 +344,18 @@ DISPATCHER_INFOS = [
         kernels={
             features.Image: F.ten_crop_image_tensor,
         },
-        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
             skip_dispatch_feature,
         ],
+        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
     ),
     DispatcherInfo(
         F.normalize,
         kernels={
             features.Image: F.normalize_image_tensor,
         },
-        skips=[
+        test_marks=[
             skip_dispatch_feature,
         ],
     ),
@@ -3,13 +3,15 @@ import functools
 import itertools
 import math
 from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

 import numpy as np
 import pytest
 import torch.testing
 import torchvision.ops
 import torchvision.prototype.transforms.functional as F
+from _pytest.mark.structures import MarkDecorator
 from datasets_utils import combinations_grid
 from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
 from torchvision.prototype import features
@@ -18,11 +20,14 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = ["KernelInfo", "KERNEL_INFOS"]

+TestID = Tuple[Optional[str], str]
+

 @dataclasses.dataclass
-class Skip:
-    test_name: str
-    reason: str
-    condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True
+class TestMark:
+    test_id: TestID
+    mark: MarkDecorator
+    condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True


 @dataclasses.dataclass
@@ -44,26 +49,22 @@ class KernelInfo:
     reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
     # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
     closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    skips: Sequence[Skip] = dataclasses.field(default_factory=list)
-    _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
+    test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
+    _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

     def __post_init__(self):
         self.kernel_name = self.kernel_name or self.kernel.__name__
         self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
-        skips_map = defaultdict(list)
-        for skip in self.skips:
-            skips_map[skip.test_name].append(skip)
-        self._skips_map = dict(skips_map)
+        test_marks_map = defaultdict(list)
+        for test_mark in self.test_marks:
+            test_marks_map[test_mark.test_id].append(test_mark)
+        self._test_marks_map = dict(test_marks_map)

-    def maybe_skip(self, *, test_name, args_kwargs, device):
-        skips = self._skips_map.get(test_name)
-        if not skips:
-            return
-
-        for skip in skips:
-            if skip.condition(args_kwargs, device):
-                pytest.skip(skip.reason)
+    def get_marks(self, test_id, args_kwargs):
+        return [
+            test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
+        ]


 DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
@@ -87,16 +88,27 @@ def pil_reference_wrapper(pil_kernel):
     return wrapper


-def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
-    return Skip(
-        "test_scripted_vs_eager",
-        condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
-        reason=reason,
+def mark_framework_limitation(test_id, reason):
+    # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
+    # framework cannot handle the kernel in general or a specific parameter combination.
+    # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
+    # still justified.
+    # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
+    # we are wasting CI resources for no reason for most of the time.
+    return TestMark(test_id, pytest.mark.skip(reason=reason))
+
+
+def xfail_python_scalar_arg_jit(name, *, reason=None):
+    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
+    return TestMark(
+        ("TestKernels", "test_scripted_vs_eager"),
+        pytest.mark.xfail(reason=reason),
+        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
     )


-def skip_integer_size_jit(name="size"):
-    return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
+def xfail_integer_size_jit(name="size"):
+    return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")


 KERNEL_INFOS = []
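
The comment block in mark_framework_limitation above hinges on the practical difference between the two marks: a skipped test is never executed and costs no CI time, while an xfailed test runs to completion and is reported XFAIL on failure or XPASS if it unexpectedly passes. A minimal, generic pytest illustration of that trade-off (not torchvision-specific):

    import pytest


    @pytest.mark.skip(reason="framework limitation; the test body is never executed")
    def test_skipped():
        raise RuntimeError("never reached, so no CI time is spent")


    @pytest.mark.xfail(reason="known bug; the body runs and is reported XFAIL on failure")
    def test_xfailed():
        assert 1 + 1 == 3  # an unexpected pass would be reported as XPASS
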
@@ -151,8 +163,7 @@ KERNEL_INFOS.extend(
 def _get_resize_sizes(image_size):
     height, width = image_size
     length = max(image_size)
-    # FIXME: enable me when the kernels are fixed
-    # yield length
+    yield length
     yield [length]
     yield (length,)
     new_height = int(height * 0.75)
@@ -236,15 +247,15 @@ KERNEL_INFOS.extend(
             reference_fn=reference_resize_image_tensor,
             reference_inputs_fn=reference_inputs_resize_image_tensor,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[
-                skip_integer_size_jit(),
+            test_marks=[
+                xfail_integer_size_jit(),
             ],
         ),
         KernelInfo(
             F.resize_bounding_box,
             sample_inputs_fn=sample_inputs_resize_bounding_box,
-            skips=[
-                skip_integer_size_jit(),
+            test_marks=[
+                xfail_integer_size_jit(),
             ],
         ),
         KernelInfo(
@@ -253,8 +264,8 @@ KERNEL_INFOS.extend(
             reference_fn=reference_resize_mask,
             reference_inputs_fn=reference_inputs_resize_mask,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[
-                skip_integer_size_jit(),
+            test_marks=[
+                xfail_integer_size_jit(),
             ],
         ),
     ]
@@ -436,16 +447,6 @@ def reference_inputs_resize_mask():
         yield ArgsKwargs(mask_loader, **affine_kwargs)


-# FIXME: @datumbox, remove this as soon as you have fixed the behavior in https://github.com/pytorch/vision/pull/6636
-def skip_scalar_shears(*test_names):
-    for test_name in test_names:
-        yield Skip(
-            test_name,
-            condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["shear"], (int, float)),
-            reason="The kernel is broken for a scalar `shear`",
-        )
-
-
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -454,7 +455,7 @@ KERNEL_INFOS.extend(
             reference_fn=pil_reference_wrapper(F.affine_image_pil),
             reference_inputs_fn=reference_inputs_affine_image_tensor,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
+            test_marks=[xfail_python_scalar_arg_jit("shear")],
         ),
         KernelInfo(
             F.affine_bounding_box,
@@ -462,13 +463,8 @@ KERNEL_INFOS.extend(
             reference_fn=reference_affine_bounding_box,
             reference_inputs_fn=reference_inputs_affine_bounding_box,
             closeness_kwargs=dict(atol=1, rtol=0),
-            skips=[
-                skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
-                *skip_scalar_shears(
-                    "test_batched_vs_single",
-                    "test_no_inplace",
-                    "test_dtype_and_device_consistency",
-                ),
+            test_marks=[
+                xfail_python_scalar_arg_jit("shear"),
             ],
         ),
         KernelInfo(
@@ -477,7 +473,7 @@ KERNEL_INFOS.extend(
             reference_fn=reference_affine_mask,
             reference_inputs_fn=reference_inputs_resize_mask,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
+            test_marks=[xfail_python_scalar_arg_jit("shear")],
         ),
     ]
 )
@@ -1093,15 +1089,15 @@ KERNEL_INFOS.extend(
             reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
             reference_inputs_fn=reference_inputs_center_crop_image_tensor,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[
-                skip_integer_size_jit("output_size"),
+            test_marks=[
+                xfail_integer_size_jit("output_size"),
             ],
         ),
         KernelInfo(
             F.center_crop_bounding_box,
             sample_inputs_fn=sample_inputs_center_crop_bounding_box,
-            skips=[
-                skip_integer_size_jit("output_size"),
+            test_marks=[
+                xfail_integer_size_jit("output_size"),
             ],
         ),
         KernelInfo(
@@ -1110,8 +1106,8 @@
             reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
             reference_inputs_fn=reference_inputs_center_crop_mask,
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-            skips=[
-                skip_integer_size_jit("output_size"),
+            test_marks=[
+                xfail_integer_size_jit("output_size"),
             ],
         ),
     ]
@@ -1138,9 +1134,9 @@ KERNEL_INFOS.append(
         F.gaussian_blur_image_tensor,
         sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        skips=[
-            skip_python_scalar_arg_jit("kernel_size"),
-            skip_python_scalar_arg_jit("sigma"),
+        test_marks=[
+            xfail_python_scalar_arg_jit("kernel_size"),
+            xfail_python_scalar_arg_jit("sigma"),
         ],
     )
 )
@@ -1551,9 +1547,9 @@
             sample_inputs_fn=sample_inputs_five_crop_image_tensor,
             reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
             reference_inputs_fn=reference_inputs_five_crop_image_tensor,
-            skips=[
-                skip_integer_size_jit(),
-                Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
+            test_marks=[
+                xfail_integer_size_jit(),
+                mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
             ],
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
         ),
@@ -1562,9 +1558,9 @@
             sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
             reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
             reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
-            skips=[
-                skip_integer_size_jit(),
-                Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
+            test_marks=[
+                xfail_integer_size_jit(),
+                mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
             ],
             closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
         ),
+import functools
 import math
 import os
@@ -26,33 +27,60 @@ def script(fn):
         raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


-@pytest.fixture(autouse=True)
-def maybe_skip(request):
-    # In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist
-    try:
-        callspec = request.node.callspec
-    except AttributeError:
-        return
-
-    try:
-        info = callspec.params["info"]
-        args_kwargs = callspec.params["args_kwargs"]
-    except KeyError:
-        return
-
-    info.maybe_skip(
-        test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu")
-    )
+def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)):
+    if condition is None:
+
+        def condition(info):
+            return True
+
+    def decorator(test_fn):
+        parts = test_fn.__qualname__.split(".")
+        if len(parts) == 1:
+            test_class_name = None
+            test_function_name = parts[0]
+        elif len(parts) == 2:
+            test_class_name, test_function_name = parts
+        else:
+            raise pytest.UsageError("Unable to parse the test class and test name from test function")
+        test_id = (test_class_name, test_function_name)
+
+        argnames = ("info", "args_kwargs")
+        argvalues = []
+        for info in infos:
+            if not condition(info):
+                continue
+
+            args_kwargs = list(args_kwargs_fn(info))
+            name = name_fn(info)
+            idx_field_len = len(str(len(args_kwargs)))
+
+            for idx, args_kwargs_ in enumerate(args_kwargs):
+                argvalues.append(
+                    pytest.param(
+                        info,
+                        args_kwargs_,
+                        marks=info.get_marks(test_id, args_kwargs_),
+                        id=f"{name}-{idx:0{idx_field_len}}",
+                    )
+                )
+
+        return pytest.mark.parametrize(argnames, argvalues)(test_fn)
+
+    return decorator


 class TestKernels:
-    sample_inputs = pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
-            for info in KERNEL_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs_fn())
-        ],
+    make_kernel_args_kwargs_parametrization = functools.partial(
+        make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name
+    )
+    sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization(
+        KERNEL_INFOS,
+        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
+    )
+    reference_inputs = make_kernel_args_kwargs_parametrization(
+        KERNEL_INFOS,
+        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+        condition=lambda info: info.reference_fn is not None,
     )

     @sample_inputs
@@ -156,15 +184,7 @@ class TestKernels:
         assert output.dtype == input.dtype
         assert output.device == input.device

-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
-            for info in KERNEL_INFOS
-            for idx, args_kwargs in enumerate(info.reference_inputs_fn())
-            if info.reference_fn is not None
-        ],
-    )
+    @reference_inputs
     def test_against_reference(self, info, args_kwargs):
         args, kwargs = args_kwargs.load("cpu")
@@ -187,15 +207,16 @@ def spy_on(mocker):

 class TestDispatchers:
-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
-            for info in DISPATCHER_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
-            if features.Image in info.kernels
-        ],
+    make_dispatcher_args_kwargs_parametrization = functools.partial(
+        make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__
+    )
+    image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
+        condition=lambda info: features.Image in info.kernels,
     )
+
+    @image_sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)
@@ -223,15 +244,7 @@ class TestDispatchers:
     def test_scriptable(self, dispatcher):
         script(dispatcher)

-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
-            for info in DISPATCHER_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
-            if features.Image in info.kernels
-        ],
-    )
+    @image_sample_inputs
     def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
         (image_feature, *other_args), kwargs = args_kwargs.load()
         image_simple_tensor = torch.Tensor(image_feature)
@@ -243,14 +256,10 @@
         spy.assert_called_once()

-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
-            for info in DISPATCHER_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
-            if features.Image in info.kernels and info.pil_kernel_info is not None
-        ],
+    @make_dispatcher_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
+        condition=lambda info: info.pil_kernel_info is not None,
     )
     def test_dispatch_pil(self, info, args_kwargs, spy_on):
         (image_feature, *other_args), kwargs = args_kwargs.load()
@@ -267,13 +276,9 @@
         spy.assert_called_once()

-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
-            for info in DISPATCHER_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs())
-        ],
+    @make_dispatcher_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(),
     )
     def test_dispatch_feature(self, info, args_kwargs, spy_on):
         (feature, *other_args), kwargs = args_kwargs.load()
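
Two details of the new parametrization helper are worth spelling out. First, the decorator derives the test_id from the test function's __qualname__, so TestKernels.test_scripted_vs_eager becomes ("TestKernels", "test_scripted_vs_eager"), which is exactly what the TestMark tuples in the info files match against. Second, the per-info index is zero-padded so generated test IDs sort naturally. A hedged sketch of both details; FakeInfo is an illustrative stand-in, not a torchvision class:

    import pytest


    def parse_test_id(test_fn):
        # "TestKernels.test_foo" -> ("TestKernels", "test_foo"); a bare
        # "test_foo" -> (None, "test_foo"), mirroring the decorator above.
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            return (None, parts[0])
        if len(parts) == 2:
            return (parts[0], parts[1])
        raise pytest.UsageError("Unable to parse the test class and test name from test function")


    class FakeInfo:
        # Illustrative stand-in for KernelInfo / DispatcherInfo.
        kernel_name = "resize_image_tensor"

        def sample_inputs_fn(self):
            return [{"size": [32 + idx]} for idx in range(12)]

        def get_marks(self, test_id, args_kwargs):
            return []


    def demo_argvalues(info, test_id):
        args_kwargs = list(info.sample_inputs_fn())
        # 12 samples -> len("12") == 2 -> IDs resize_image_tensor-00 ... -11,
        # which sort naturally in pytest's output.
        idx_field_len = len(str(len(args_kwargs)))
        return [
            pytest.param(info, ak, marks=info.get_marks(test_id, ak), id=f"{info.kernel_name}-{idx:0{idx_field_len}}")
            for idx, ak in enumerate(args_kwargs)
        ]
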