Unverified commit 3f1d9f6b authored by Philip Meier, committed by GitHub

Refactor `KernelInfo` and `DispatcherInfo` (#6710)

* make args and kwargs in ArgsKwargs more accessible

* refactor KernelInfo and DispatcherInfo

* remove ArgsKwargs __getitem__ shortcut again
parent 17969eba
@@ -3,6 +3,7 @@
import collections.abc
import dataclasses
import functools
from collections import defaultdict
from typing import Callable, Optional, Sequence, Tuple, Union
import PIL.Image
@@ -47,6 +48,9 @@ __all__ = [
"make_masks",
"make_video",
"make_videos",
"TestMark",
"mark_framework_limitation",
"InfoBase",
]
@@ -588,3 +592,52 @@ def make_video_loaders(
make_videos = from_loaders(make_video_loaders)
class TestMark:
def __init__(
self,
# Tuple of test class name and test function name that identifies the test the mark is applied to. If there is
# no test class, i.e. for a standalone test function, use `None` as the class name.
test_id,
# `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail`
mark,
*,
# Callable that will be passed an `ArgsKwargs` and should return a boolean indicating whether the mark should be
# applied. If omitted, the mark is always applied.
condition=None,
):
self.test_id = test_id
self.mark = mark
self.condition = condition or (lambda args_kwargs: True)
def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that exist only because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we would be wasting CI resources for no reason most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))
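For orientation, here is a hedged sketch of how these marks could be constructed. The test ids and the `fill` condition are made up, and the `args_kwargs.kwargs` attribute is assumed from this PR's "make args and kwargs in ArgsKwargs more accessible" change:

```python
import pytest

# Illustrative only: a hard skip for a framework limitation, and a conditional
# xfail that is evaluated per parametrization via `condition`.
skip_integer_size = mark_framework_limitation(
    ("TestKernels", "test_against_reference"),
    "Reference kernels do not support integer size arguments.",
)

xfail_scalar_fill = TestMark(
    ("TestKernels", "test_scripted_vs_eager"),
    pytest.mark.xfail(reason="Scalar `fill` is not supported when scripting."),
    # Assumed attribute from this PR: `args_kwargs.kwargs` is the kwargs dict.
    condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get("fill"), (int, float)),
)
```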
class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
# Identifier of the info that shows up in the parametrization.
self.id = id
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
self.test_marks = test_marks or []
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
self.closeness_kwargs = closeness_kwargs or dict()
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
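A minimal sanity-check sketch of the mark lookup (the id and marks are made up; passing `None` as `args_kwargs` works here only because the default condition ignores its argument):

```python
info = InfoBase(
    id="resize",
    test_marks=[
        mark_framework_limitation(("TestKernels", "test_scripted_vs_eager"), "not scriptable yet"),
    ],
)

# Marks are returned for the matching test id and filtered by each mark's condition.
assert len(info.get_marks(("TestKernels", "test_scripted_vs_eager"), None)) == 1
assert info.get_marks(("TestKernels", "test_against_reference"), None) == []
```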
import collections.abc
import dataclasses
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Type
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
@dataclasses.dataclass
class PILKernelInfo:
kernel: Callable
kernel_name: str = dataclasses.field(default=None)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
class PILKernelInfo(InfoBase):
def __init__(
self,
kernel,
*,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name=None,
):
super().__init__(id=kernel_name or kernel.__name__)
self.kernel = kernel
class DispatcherInfo(InfoBase):
_KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
def __init__(
self,
dispatcher,
*,
# Dictionary mapping feature types to the kernels the dispatcher dispatches to.
kernels,
# If omitted, no PIL dispatch test will be performed.
pil_kernel_info=None,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.dispatcher = dispatcher
self.kernels = kernels
self.pil_kernel_info = pil_kernel_info
kernel_infos = {}
for feature_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info:
raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
)
kernel_infos[feature_type] = kernel_info
self.kernel_infos = kernel_infos
def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
if feature_type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")
for feature_type in feature_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(feature_type)
if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")
sample_inputs = kernel_info.sample_inputs_fn()
sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
if not filter_metadata:
yield from sample_inputs
else:
......
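To make the registration contract above concrete, a hypothetical entry could look as follows. The `resize` kernels are illustrative stand-ins, not entries from this PR; each kernel must already have a `KernelInfo`, otherwise the constructor raises `pytest.UsageError`:

```python
# Hypothetical sketch only.
resize_info = DispatcherInfo(
    F.resize,
    kernels={
        features.Image: F.resize_image_tensor,
        features.BoundingBox: F.resize_bounding_box,
        features.Mask: F.resize_mask,
    },
    pil_kernel_info=PILKernelInfo(F.resize_image_pil),
)

# Yields `ArgsKwargs` for the image kernel only; an unregistered feature type
# would raise `pytest.UsageError` instead.
sample_inputs = list(resize_info.sample_inputs(features.Image))
```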
import dataclasses
import functools
import itertools
import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F
from _pytest.mark.structures import MarkDecorator
from common_utils import cycle_over
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
InfoBase,
make_bounding_box_loaders,
make_image_loader,
make_image_loaders,
make_mask_loaders,
make_video_loaders,
mark_framework_limitation,
TestMark,
VALID_EXTRA_DIMS,
)
from torchvision.prototype import features
@@ -29,51 +27,35 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value
__all__ = ["KernelInfo", "KERNEL_INFOS"]
TestID = Tuple[Optional[str], str]
@dataclasses.dataclass
class TestMark:
test_id: TestID
mark: MarkDecorator
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True
@dataclasses.dataclass
class KernelInfo:
kernel: Callable
# Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths, but
# should not include extensive parameter combinations, to keep the overall test count moderate.
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name: str = dataclasses.field(default=None)
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or, more precisely, an object that can be compared to a
# tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn: Optional[Callable] = None
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
class KernelInfo(InfoBase):
def __init__(
self,
kernel,
*,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name=None,
# Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths, but
# should not include extensive parameter combinations, to keep the overall test count moderate.
sample_inputs_fn,
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
# take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
# happen inside the function. It should return a tensor or, more precisely, an object that can be compared
# to a tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn=None,
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn=None,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.kernel = kernel
self.sample_inputs_fn = sample_inputs_fn
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
@@ -97,16 +79,6 @@ def pil_reference_wrapper(pil_kernel):
return wrapper
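For reference, a hedged sketch of a complete `KernelInfo` entry with these pieces wired together; the `horizontal_flip` names are illustrative and the sample-input generator is deliberately minimal:

```python
# Hypothetical entry sketch, not taken from this PR.
def sample_inputs_horizontal_flip_image_tensor():
    # Yield one `ArgsKwargs` per image loader; real entries vary dtypes, sizes, etc.
    for image_loader in make_image_loaders():
        yield ArgsKwargs(image_loader)

KERNEL_INFOS.append(
    KernelInfo(
        F.horizontal_flip_image_tensor,
        sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
        # Compare against the PIL kernel through the tensor <-> PIL wrapper above.
        reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
```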
def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that exist only because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we would be wasting CI resources for no reason most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))
def xfail_jit_python_scalar_arg(name, *, reason=None):
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
return TestMark(
......
import functools
import math
import os
@@ -27,7 +26,7 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)):
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
if condition is None:
def condition(info):
@@ -41,7 +40,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
elif len(parts) == 2:
test_class_name, test_function_name = parts
else:
raise pytest.UsageError("Unable to parse the test class and test name from test function")
raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
test_id = (test_class_name, test_function_name)
argnames = ("info", "args_kwargs")
@@ -51,7 +50,6 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
continue
args_kwargs = list(args_kwargs_fn(info))
name = name_fn(info)
idx_field_len = len(str(len(args_kwargs)))
for idx, args_kwargs_ in enumerate(args_kwargs):
@@ -60,7 +58,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
info,
args_kwargs_,
marks=info.get_marks(test_id, args_kwargs_),
id=f"{name}-{idx:0{idx_field_len}}",
id=f"{info.id}-{idx:0{idx_field_len}}",
)
)
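Stripped of test-id parsing and condition filtering, the helper reduces to the standard `pytest.param` pattern; a self-contained sketch under those simplifying assumptions:

```python
import pytest

def simple_parametrization(infos, args_kwargs_fn):
    params = []
    for info in infos:
        args_kwargs_list = list(args_kwargs_fn(info))
        # Zero-pad the index so the generated test ids sort lexicographically.
        idx_field_len = len(str(len(args_kwargs_list)))
        for idx, args_kwargs in enumerate(args_kwargs_list):
            params.append(pytest.param(info, args_kwargs, id=f"{info.id}-{idx:0{idx_field_len}}"))
    return pytest.mark.parametrize(("info", "args_kwargs"), params)
```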
@@ -70,14 +68,11 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
class TestKernels:
make_kernel_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name
)
sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization(
sample_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
)
reference_inputs = make_kernel_args_kwargs_parametrization(
reference_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
condition=lambda info: info.reference_fn is not None,
@@ -208,10 +203,7 @@ def spy_on(mocker):
class TestDispatchers:
make_dispatcher_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__
)
image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization(
image_sample_inputs = make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: features.Image in info.kernels,
@@ -251,13 +243,13 @@ class TestDispatchers:
image_simple_tensor = torch.Tensor(image_feature)
kernel_info = info.kernel_infos[features.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.kernel_name)
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
info.dispatcher(image_simple_tensor, *other_args, **kwargs)
spy.assert_called_once()
@make_dispatcher_args_kwargs_parametrization(
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: info.pil_kernel_info is not None,
@@ -271,22 +263,23 @@ class TestDispatchers:
image_pil = F.to_image_pil(image_feature)
pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.kernel_name)
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
info.dispatcher(image_pil, *other_args, **kwargs)
spy.assert_called_once()
@make_dispatcher_args_kwargs_parametrization(
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_feature(self, info, args_kwargs, spy_on):
(feature, *other_args), kwargs = args_kwargs.load()
method = getattr(feature, info.method_name)
method_name = info.id
method = getattr(feature, method_name)
feature_type = type(feature)
spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{info.method_name}")
spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}")
info.dispatcher(feature, *other_args, **kwargs)
......
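The `spy_on` fixture is defined elsewhere in this file; as a self-contained illustration of the same "wrap the method and assert it was hit" pattern, here is a toy version using `unittest.mock` with made-up `Feature`/`dispatcher` stand-ins:

```python
from unittest import mock

class Feature:
    def resize(self, size):
        return f"resized to {size}"

def dispatcher(feature, size):
    # Dispatches to the feature's method, mirroring the prototype dispatchers.
    return feature.resize(size)

feature = Feature()
# Wrap the bound method so the real call still happens while being recorded.
with mock.patch.object(feature, "resize", wraps=feature.resize) as spy:
    dispatcher(feature, 32)
    spy.assert_called_once_with(32)
```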