Unverified Commit 3f1d9f6b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Refactor `KernelInfo` and `DispatcherInfo` (#6710)

* make args and kwargs in ArgsKwargs more accessible

* refactor KernelInfo and DispatcherInfo

* remove ArgsKwargs __getitem__ shortcut again
parent 17969eba
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import collections.abc import collections.abc
import dataclasses import dataclasses
import functools import functools
from collections import defaultdict
from typing import Callable, Optional, Sequence, Tuple, Union from typing import Callable, Optional, Sequence, Tuple, Union
import PIL.Image import PIL.Image
...@@ -47,6 +48,9 @@ __all__ = [ ...@@ -47,6 +48,9 @@ __all__ = [
"make_masks", "make_masks",
"make_video", "make_video",
"make_videos", "make_videos",
"TestMark",
"mark_framework_limitation",
"InfoBase",
] ]
...@@ -588,3 +592,52 @@ def make_video_loaders( ...@@ -588,3 +592,52 @@ def make_video_loaders(
make_videos = from_loaders(make_video_loaders) make_videos = from_loaders(make_video_loaders)
class TestMark:
def __init__(
self,
# Tuple of test class name and test function name that identifies the test the mark is applied to. If there is
# no test class, i.e. a standalone test function, use `None`.
test_id,
# `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail`
mark,
*,
# Callable, that will be passed an `ArgsKwargs` and should return a boolean to indicate if the mark will be
# applied. If omitted, defaults to always apply.
condition=None,
):
self.test_id = test_id
self.mark = mark
self.condition = condition or (lambda args_kwargs: True)
def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we are wasting CI resources for no reason for most of the time
return TestMark(test_id, pytest.mark.skip(reason=reason))
class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
# Identifier if the info that shows up the parametrization.
self.id = id
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
self.test_marks = test_marks or []
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
self.closeness_kwargs = closeness_kwargs or dict()
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
import collections.abc import collections.abc
import dataclasses
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Type
import pytest import pytest
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torchvision.prototype import features from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
@dataclasses.dataclass
class PILKernelInfo:
kernel: Callable
kernel_name: str = dataclasses.field(default=None)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
class PILKernelInfo(InfoBase):
@dataclasses.dataclass def __init__(
class DispatcherInfo: self,
dispatcher: Callable kernel,
kernels: Dict[Type, Callable] *,
pil_kernel_info: Optional[PILKernelInfo] = None # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
method_name: str = dataclasses.field(default=None) # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) kernel_name=None,
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) ):
super().__init__(id=kernel_name or kernel.__name__)
def __post_init__(self): self.kernel = kernel
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
test_marks_map = defaultdict(list) class DispatcherInfo(InfoBase):
for test_mark in self.test_marks: _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map) def __init__(
self,
def get_marks(self, test_id, args_kwargs): dispatcher,
return [ *,
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) # Dictionary of types that map to the kernel the dispatcher dispatches to.
] kernels,
# If omitted, no PIL dispatch test will be performed.
pil_kernel_info=None,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.dispatcher = dispatcher
self.kernels = kernels
self.pil_kernel_info = pil_kernel_info
kernel_infos = {}
for feature_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info:
raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
)
kernel_infos[feature_type] = kernel_info
self.kernel_infos = kernel_infos
def sample_inputs(self, *feature_types, filter_metadata=True): def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys(): for feature_type in feature_types or self.kernel_infos.keys():
if feature_type not in self.kernels: kernel_info = self.kernel_infos.get(feature_type)
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}") if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
sample_inputs = kernel_info.sample_inputs_fn()
sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
if not filter_metadata: if not filter_metadata:
yield from sample_inputs yield from sample_inputs
else: else:
......
import dataclasses
import functools import functools
import itertools import itertools
import math import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import pytest import pytest
import torch.testing import torch.testing
import torchvision.ops import torchvision.ops
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from _pytest.mark.structures import MarkDecorator
from common_utils import cycle_over from common_utils import cycle_over
from datasets_utils import combinations_grid from datasets_utils import combinations_grid
from prototype_common_utils import ( from prototype_common_utils import (
ArgsKwargs, ArgsKwargs,
InfoBase,
make_bounding_box_loaders, make_bounding_box_loaders,
make_image_loader, make_image_loader,
make_image_loaders, make_image_loaders,
make_mask_loaders, make_mask_loaders,
make_video_loaders, make_video_loaders,
mark_framework_limitation,
TestMark,
VALID_EXTRA_DIMS, VALID_EXTRA_DIMS,
) )
from torchvision.prototype import features from torchvision.prototype import features
...@@ -29,51 +27,35 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value ...@@ -29,51 +27,35 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value
__all__ = ["KernelInfo", "KERNEL_INFOS"] __all__ = ["KernelInfo", "KERNEL_INFOS"]
TestID = Tuple[Optional[str], str] class KernelInfo(InfoBase):
def __init__(
self,
@dataclasses.dataclass kernel,
class TestMark: *,
test_id: TestID # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
mark: MarkDecorator # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True kernel_name=None,
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
# should not include extensive parameter combinations to keep to overall test count moderate.
@dataclasses.dataclass sample_inputs_fn,
class KernelInfo: # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
kernel: Callable # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should # happen inside the function. It should return a tensor or to be more precise an object that can be compared to
# not include extensive parameter combinations to keep to overall test count moderate. # a tensor by `assert_close`. If omitted, no reference test will be performed.
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]] reference_fn=None,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then # values to be tested. If not specified, `sample_inputs_fn` will be used.
kernel_name: str = dataclasses.field(default=None) reference_inputs_fn=None,
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take # See InfoBase
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen test_marks=None,
# inside the function. It should return a tensor or to be more precise an object that can be compared to a # See InfoBase
# tensor by `assert_close`. If omitted, no reference test will be performed. closeness_kwargs=None,
reference_fn: Optional[Callable] = None ):
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
# values to be tested. If not specified, `sample_inputs_fn` will be used. self.kernel = kernel
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None self.sample_inputs_fn = sample_inputs_fn
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. self.reference_fn = reference_fn
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) self.reference_inputs_fn = reference_inputs_fn
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
...@@ -97,16 +79,6 @@ def pil_reference_wrapper(pil_kernel): ...@@ -97,16 +79,6 @@ def pil_reference_wrapper(pil_kernel):
return wrapper return wrapper
def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we are wasting CI resources for no reason for most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))
def xfail_jit_python_scalar_arg(name, *, reason=None): def xfail_jit_python_scalar_arg(name, *, reason=None):
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting" reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
return TestMark( return TestMark(
......
import functools
import math import math
import os import os
...@@ -27,7 +26,7 @@ def script(fn): ...@@ -27,7 +26,7 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)): def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
if condition is None: if condition is None:
def condition(info): def condition(info):
...@@ -41,7 +40,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n ...@@ -41,7 +40,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
elif len(parts) == 2: elif len(parts) == 2:
test_class_name, test_function_name = parts test_class_name, test_function_name = parts
else: else:
raise pytest.UsageError("Unable to parse the test class and test name from test function") raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
test_id = (test_class_name, test_function_name) test_id = (test_class_name, test_function_name)
argnames = ("info", "args_kwargs") argnames = ("info", "args_kwargs")
...@@ -51,7 +50,6 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n ...@@ -51,7 +50,6 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
continue continue
args_kwargs = list(args_kwargs_fn(info)) args_kwargs = list(args_kwargs_fn(info))
name = name_fn(info)
idx_field_len = len(str(len(args_kwargs))) idx_field_len = len(str(len(args_kwargs)))
for idx, args_kwargs_ in enumerate(args_kwargs): for idx, args_kwargs_ in enumerate(args_kwargs):
...@@ -60,7 +58,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n ...@@ -60,7 +58,7 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
info, info,
args_kwargs_, args_kwargs_,
marks=info.get_marks(test_id, args_kwargs_), marks=info.get_marks(test_id, args_kwargs_),
id=f"{name}-{idx:0{idx_field_len}}", id=f"{info.id}-{idx:0{idx_field_len}}",
) )
) )
...@@ -70,14 +68,11 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n ...@@ -70,14 +68,11 @@ def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, n
class TestKernels: class TestKernels:
make_kernel_args_kwargs_parametrization = functools.partial( sample_inputs = make_info_args_kwargs_parametrization(
make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name
)
sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization(
KERNEL_INFOS, KERNEL_INFOS,
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(), args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
) )
reference_inputs = make_kernel_args_kwargs_parametrization( reference_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS, KERNEL_INFOS,
args_kwargs_fn=lambda info: info.reference_inputs_fn(), args_kwargs_fn=lambda info: info.reference_inputs_fn(),
condition=lambda info: info.reference_fn is not None, condition=lambda info: info.reference_fn is not None,
...@@ -208,10 +203,7 @@ def spy_on(mocker): ...@@ -208,10 +203,7 @@ def spy_on(mocker):
class TestDispatchers: class TestDispatchers:
make_dispatcher_args_kwargs_parametrization = functools.partial( image_sample_inputs = make_info_args_kwargs_parametrization(
make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__
)
image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization(
DISPATCHER_INFOS, DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image), args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: features.Image in info.kernels, condition=lambda info: features.Image in info.kernels,
...@@ -251,13 +243,13 @@ class TestDispatchers: ...@@ -251,13 +243,13 @@ class TestDispatchers:
image_simple_tensor = torch.Tensor(image_feature) image_simple_tensor = torch.Tensor(image_feature)
kernel_info = info.kernel_infos[features.Image] kernel_info = info.kernel_infos[features.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.kernel_name) spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
info.dispatcher(image_simple_tensor, *other_args, **kwargs) info.dispatcher(image_simple_tensor, *other_args, **kwargs)
spy.assert_called_once() spy.assert_called_once()
@make_dispatcher_args_kwargs_parametrization( @make_info_args_kwargs_parametrization(
DISPATCHER_INFOS, DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image), args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: info.pil_kernel_info is not None, condition=lambda info: info.pil_kernel_info is not None,
...@@ -271,22 +263,23 @@ class TestDispatchers: ...@@ -271,22 +263,23 @@ class TestDispatchers:
image_pil = F.to_image_pil(image_feature) image_pil = F.to_image_pil(image_feature)
pil_kernel_info = info.pil_kernel_info pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.kernel_name) spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
info.dispatcher(image_pil, *other_args, **kwargs) info.dispatcher(image_pil, *other_args, **kwargs)
spy.assert_called_once() spy.assert_called_once()
@make_dispatcher_args_kwargs_parametrization( @make_info_args_kwargs_parametrization(
DISPATCHER_INFOS, DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(), args_kwargs_fn=lambda info: info.sample_inputs(),
) )
def test_dispatch_feature(self, info, args_kwargs, spy_on): def test_dispatch_feature(self, info, args_kwargs, spy_on):
(feature, *other_args), kwargs = args_kwargs.load() (feature, *other_args), kwargs = args_kwargs.load()
method = getattr(feature, info.method_name) method_name = info.id
method = getattr(feature, method_name)
feature_type = type(feature) feature_type = type(feature)
spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{info.method_name}") spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}")
info.dispatcher(feature, *other_args, **kwargs) info.dispatcher(feature, *other_args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment