Unverified Commit 46eae182 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use pytest markers instead of custom solution for prototype transforms functional tests (#6653)

* use pytest markers instead of custom solution for prototype transforms functional tests

* cleanup

* cleanup

* trigger CI
parent a46c4f0c
......@@ -2,12 +2,12 @@ import collections.abc
import dataclasses
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Type
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import BoundingBoxLoader
from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
......@@ -24,35 +24,27 @@ class PILKernelInfo:
self.kernel_name = self.kernel_name or self.kernel.__name__
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
    """Build a ``Skip`` for ``test_scripted_smoke`` that fires when the kwarg ``name`` is a bare int/float."""

    def _is_python_scalar(args_kwargs, device):
        # Only the argument's type matters here; ``device`` is part of the condition signature but unused.
        return isinstance(args_kwargs.kwargs[name], (int, float))

    return Skip("test_scripted_smoke", condition=_is_python_scalar, reason=reason)
def skip_integer_size_jit(name="size"):
    """Skip scripting tests when an integer is passed for the size-like argument ``name``."""
    reason = "Integer size is not supported when scripting."
    return skip_python_scalar_arg_jit(name, reason=reason)
@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
......@@ -70,17 +62,27 @@ class DispatcherInfo:
yield args_kwargs
def maybe_skip(self, *, test_name, args_kwargs, device):
skips = self._skips_map.get(test_name)
if not skips:
return
for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
def xfail_python_scalar_arg_jit(name, *, reason=None):
    """``TestMark`` that xfails ``TestDispatchers.test_scripted_smoke`` when kwarg ``name`` is a bare int/float."""
    if not reason:
        reason = f"Python scalar int or float for `{name}` is not supported when scripting"

    def _scalar_passed(args_kwargs):
        return isinstance(args_kwargs.kwargs[name], (int, float))

    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        pytest.mark.xfail(reason=reason),
        condition=_scalar_passed,
    )
def xfail_integer_size_jit(name="size"):
    """xfail scripting tests when an integer is passed for the size-like argument ``name``."""
    reason = f"Integer `{name}` is not supported when scripting."
    return xfail_python_scalar_arg_jit(name, reason=reason)
def fill_sequence_needs_broadcast(args_kwargs, device):
skip_dispatch_feature = TestMark(
("TestDispatchers", "test_dispatch_feature"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
)
def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
......@@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
return image_loader.num_channels > 1
skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
"test_dispatch_pil",
xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
),
condition=fill_sequence_needs_broadcast,
reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
)
skip_dispatch_feature = Skip(
"test_dispatch_feature",
reason="Dispatcher doesn't support arbitrary feature dispatch.",
)
......@@ -123,8 +122,8 @@ DISPATCHER_INFOS = [
features.Mask: F.resize_mask,
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
],
),
DispatcherInfo(
......@@ -135,9 +134,9 @@ DISPATCHER_INFOS = [
features.Mask: F.affine_mask,
},
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
xfail_python_scalar_arg_jit("shear"),
],
),
DispatcherInfo(
......@@ -166,16 +165,6 @@ DISPATCHER_INFOS = [
features.Mask: F.crop_mask,
},
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
skips=[
Skip(
"test_dispatch_feature",
condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
reason=(
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
"since that is sufficient for the kernel."
),
)
],
),
DispatcherInfo(
F.resized_crop,
......@@ -193,10 +182,20 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.pad_bounding_box,
features.Mask: F.pad_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
)
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
)
],
),
DispatcherInfo(
F.perspective,
......@@ -205,10 +204,10 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
],
),
DispatcherInfo(
F.elastic,
......@@ -227,8 +226,8 @@ DISPATCHER_INFOS = [
features.Mask: F.center_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
skips=[
skip_integer_size_jit("output_size"),
test_marks=[
xfail_integer_size_jit("output_size"),
],
),
DispatcherInfo(
......@@ -237,9 +236,9 @@ DISPATCHER_INFOS = [
features.Image: F.gaussian_blur_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
test_marks=[
xfail_python_scalar_arg_jit("kernel_size"),
xfail_python_scalar_arg_jit("sigma"),
],
),
DispatcherInfo(
......@@ -290,7 +289,7 @@ DISPATCHER_INFOS = [
features.Image: F.erase_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
skips=[
test_marks=[
skip_dispatch_feature,
],
),
......@@ -335,8 +334,8 @@ DISPATCHER_INFOS = [
features.Image: F.five_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
skip_dispatch_feature,
],
),
......@@ -345,18 +344,18 @@ DISPATCHER_INFOS = [
kernels={
features.Image: F.ten_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
skip_dispatch_feature,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
),
DispatcherInfo(
F.normalize,
kernels={
features.Image: F.normalize_image_tensor,
},
skips=[
test_marks=[
skip_dispatch_feature,
],
),
......
......@@ -3,13 +3,15 @@ import functools
import itertools
import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F
from _pytest.mark.structures import MarkDecorator
from datasets_utils import combinations_grid
from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
from torchvision.prototype import features
......@@ -18,11 +20,14 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value
__all__ = ["KernelInfo", "KERNEL_INFOS"]
TestID = Tuple[Optional[str], str]
@dataclasses.dataclass
class Skip:
test_name: str
reason: str
condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True
class TestMark:
test_id: TestID
mark: MarkDecorator
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True
@dataclasses.dataclass
......@@ -44,26 +49,22 @@ class KernelInfo:
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def maybe_skip(self, *, test_name, args_kwargs, device):
skips = self._skips_map.get(test_name)
if not skips:
return
for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
......@@ -87,16 +88,27 @@ def pil_reference_wrapper(pil_kernel):
return wrapper
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
return Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
reason=reason,
def mark_framework_limitation(test_id, reason):
    """Single entry point for skip marks that exist only because the *test framework* cannot
    handle the kernel in general or a specific parameter combination.

    As development progresses, the ``mark.skip`` here can occasionally be flipped to
    ``mark.xfail`` to check whether the skip is still justified. We don't use ``mark.xfail``
    permanently, because xfail actually runs the test until an error happens and would thus
    waste CI resources for no reason most of the time.
    """
    return TestMark(test_id, pytest.mark.skip(reason=reason))
def xfail_python_scalar_arg_jit(name, *, reason=None):
    """``TestMark`` that xfails ``TestKernels.test_scripted_vs_eager`` when kwarg ``name`` is a bare int/float."""

    def _scalar_passed(args_kwargs):
        return isinstance(args_kwargs.kwargs[name], (int, float))

    return TestMark(
        ("TestKernels", "test_scripted_vs_eager"),
        pytest.mark.xfail(reason=reason or f"Python scalar int or float for `{name}` is not supported when scripting"),
        condition=_scalar_passed,
    )
def skip_integer_size_jit(name="size"):
    """Skip scripted-vs-eager tests when an integer is passed for the size-like argument ``name``."""
    reason = "Integer size is not supported when scripting."
    return skip_python_scalar_arg_jit(name, reason=reason)
def xfail_integer_size_jit(name="size"):
    """xfail scripted-vs-eager tests when an integer is passed for the size-like argument ``name``."""
    reason = f"Integer `{name}` is not supported when scripting."
    return xfail_python_scalar_arg_jit(name, reason=reason)
KERNEL_INFOS = []
......@@ -151,8 +163,7 @@ KERNEL_INFOS.extend(
def _get_resize_sizes(image_size):
height, width = image_size
length = max(image_size)
# FIXME: enable me when the kernels are fixed
# yield length
yield length
yield [length]
yield (length,)
new_height = int(height * 0.75)
......@@ -236,15 +247,15 @@ KERNEL_INFOS.extend(
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
],
),
KernelInfo(
F.resize_bounding_box,
sample_inputs_fn=sample_inputs_resize_bounding_box,
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
],
),
KernelInfo(
......@@ -253,8 +264,8 @@ KERNEL_INFOS.extend(
reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
],
),
]
......@@ -436,16 +447,6 @@ def reference_inputs_resize_mask():
yield ArgsKwargs(mask_loader, **affine_kwargs)
# FIXME: @datumbox, remove this as soon as you have fixed the behavior in https://github.com/pytorch/vision/pull/6636
def skip_scalar_shears(*test_names):
    """Yield one ``Skip`` per test in ``test_names`` that fires whenever ``shear`` is a scalar."""

    def _scalar_shear(args_kwargs, device):
        # ``device`` is part of the condition signature but irrelevant to the check.
        return isinstance(args_kwargs.kwargs["shear"], (int, float))

    for test_name in test_names:
        yield Skip(
            test_name,
            condition=_scalar_shear,
            reason="The kernel is broken for a scalar `shear`",
        )
KERNEL_INFOS.extend(
[
KernelInfo(
......@@ -454,7 +455,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
test_marks=[xfail_python_scalar_arg_jit("shear")],
),
KernelInfo(
F.affine_bounding_box,
......@@ -462,13 +463,8 @@ KERNEL_INFOS.extend(
reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box,
closeness_kwargs=dict(atol=1, rtol=0),
skips=[
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
*skip_scalar_shears(
"test_batched_vs_single",
"test_no_inplace",
"test_dtype_and_device_consistency",
),
test_marks=[
xfail_python_scalar_arg_jit("shear"),
],
),
KernelInfo(
......@@ -477,7 +473,7 @@ KERNEL_INFOS.extend(
reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
test_marks=[xfail_python_scalar_arg_jit("shear")],
),
]
)
......@@ -1093,15 +1089,15 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit("output_size"),
test_marks=[
xfail_integer_size_jit("output_size"),
],
),
KernelInfo(
F.center_crop_bounding_box,
sample_inputs_fn=sample_inputs_center_crop_bounding_box,
skips=[
skip_integer_size_jit("output_size"),
test_marks=[
xfail_integer_size_jit("output_size"),
],
),
KernelInfo(
......@@ -1110,8 +1106,8 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_integer_size_jit("output_size"),
test_marks=[
xfail_integer_size_jit("output_size"),
],
),
]
......@@ -1138,9 +1134,9 @@ KERNEL_INFOS.append(
F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
test_marks=[
xfail_python_scalar_arg_jit("kernel_size"),
xfail_python_scalar_arg_jit("sigma"),
],
)
)
......@@ -1551,9 +1547,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
skips=[
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
test_marks=[
xfail_integer_size_jit(),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
......@@ -1562,9 +1558,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
skips=[
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
test_marks=[
xfail_integer_size_jit(),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
......
import functools
import math
import os
......@@ -26,33 +27,60 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
@pytest.fixture(autouse=True)
def maybe_skip(request):
# In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist
try:
callspec = request.node.callspec
except AttributeError:
return
def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)):
if condition is None:
try:
info = callspec.params["info"]
args_kwargs = callspec.params["args_kwargs"]
except KeyError:
return
def condition(info):
return True
info.maybe_skip(
test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu")
)
def decorator(test_fn):
parts = test_fn.__qualname__.split(".")
if len(parts) == 1:
test_class_name = None
test_function_name = parts[0]
elif len(parts) == 2:
test_class_name, test_function_name = parts
else:
raise pytest.UsageError("Unable to parse the test class and test name from test function")
test_id = (test_class_name, test_function_name)
argnames = ("info", "args_kwargs")
argvalues = []
for info in infos:
if not condition(info):
continue
args_kwargs = list(args_kwargs_fn(info))
name = name_fn(info)
idx_field_len = len(str(len(args_kwargs)))
for idx, args_kwargs_ in enumerate(args_kwargs):
argvalues.append(
pytest.param(
info,
args_kwargs_,
marks=info.get_marks(test_id, args_kwargs_),
id=f"{name}-{idx:0{idx_field_len}}",
)
)
return pytest.mark.parametrize(argnames, argvalues)(test_fn)
return decorator
class TestKernels:
sample_inputs = pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
for info in KERNEL_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs_fn())
],
make_kernel_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name
)
sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
)
reference_inputs = make_kernel_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
condition=lambda info: info.reference_fn is not None,
)
@sample_inputs
......@@ -156,15 +184,7 @@ class TestKernels:
assert output.dtype == input.dtype
assert output.device == input.device
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
for info in KERNEL_INFOS
for idx, args_kwargs in enumerate(info.reference_inputs_fn())
if info.reference_fn is not None
],
)
@reference_inputs
def test_against_reference(self, info, args_kwargs):
args, kwargs = args_kwargs.load("cpu")
......@@ -187,15 +207,16 @@ def spy_on(mocker):
class TestDispatchers:
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels
],
make_dispatcher_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__
)
image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: features.Image in info.kernels,
)
@image_sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_smoke(self, info, args_kwargs, device):
dispatcher = script(info.dispatcher)
......@@ -223,15 +244,7 @@ class TestDispatchers:
def test_scriptable(self, dispatcher):
script(dispatcher)
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels
],
)
@image_sample_inputs
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_feature)
......@@ -243,14 +256,10 @@ class TestDispatchers:
spy.assert_called_once()
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels and info.pil_kernel_info is not None
],
@make_dispatcher_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: info.pil_kernel_info is not None,
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
......@@ -267,13 +276,9 @@ class TestDispatchers:
spy.assert_called_once()
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs())
],
@make_dispatcher_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_feature(self, info, args_kwargs, spy_on):
(feature, *other_args), kwargs = args_kwargs.load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment