Unverified Commit 3c9ae0ac authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add more KernelInfo's and DispatcherInfo's (#6626)

* add KernelInfo for adjust_brightness

* add KernelInfo for adjust_contrast

* add KernelInfo for adjust_hue

* add KernelInfo for adjust_saturation

* add KernelInfo for clamp_bounding_box

* add KernelInfo for {five, ten}_crop_image_tensor as well as skip functionality

* add KernelInfo for normalize

* add KernelInfo for adjust_gamma

* cleanup

* add DispatcherInfo's for previously added KernelInfo's

* add dispatcher info for elastic
parent 658ca539
import dataclasses
from typing import Callable, Dict, Type
from typing import Callable, Dict, Sequence, Type
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
......@@ -15,6 +15,11 @@ KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KE
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
def __post_init__(self):
    # Index the skip rules by the name of the test they target so that
    # maybe_skip can resolve them with a single dict lookup.
    by_test_name = {}
    for rule in self.skips:
        by_test_name[rule.test_name] = rule
    self._skips_map = by_test_name
def sample_inputs(self, *types):
for type in types or self.kernels.keys():
......@@ -23,6 +28,11 @@ class DispatcherInfo:
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
def maybe_skip(self, *, test_name, args_kwargs, device):
    """Skip the current pytest test if a Skip is registered for it and its condition holds."""
    if test_name not in self._skips_map:
        return
    rule = self._skips_map[test_name]
    if rule.condition(args_kwargs, device):
        pytest.skip(rule.reason)
DISPATCHER_INFOS = [
DispatcherInfo(
......@@ -97,6 +107,14 @@ DISPATCHER_INFOS = [
features.Mask: F.perspective_mask,
},
),
DispatcherInfo(
F.elastic,
kernels={
features.Image: F.elastic_image_tensor,
features.BoundingBox: F.elastic_bounding_box,
features.Mask: F.elastic_mask,
},
),
DispatcherInfo(
F.center_crop,
kernels={
......@@ -153,4 +171,66 @@ DISPATCHER_INFOS = [
features.Image: F.erase_image_tensor,
},
),
DispatcherInfo(
F.adjust_brightness,
kernels={
features.Image: F.adjust_brightness_image_tensor,
},
),
DispatcherInfo(
F.adjust_contrast,
kernels={
features.Image: F.adjust_contrast_image_tensor,
},
),
DispatcherInfo(
F.adjust_gamma,
kernels={
features.Image: F.adjust_gamma_image_tensor,
},
),
DispatcherInfo(
F.adjust_hue,
kernels={
features.Image: F.adjust_hue_image_tensor,
},
),
DispatcherInfo(
F.adjust_saturation,
kernels={
features.Image: F.adjust_saturation_image_tensor,
},
),
DispatcherInfo(
F.five_crop,
kernels={
features.Image: F.five_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
],
),
DispatcherInfo(
F.ten_crop,
kernels={
features.Image: F.ten_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
],
),
DispatcherInfo(
F.normalize,
kernels={
features.Image: F.normalize_image_tensor,
},
),
]
......@@ -2,7 +2,7 @@ import dataclasses
import functools
import itertools
import math
from typing import Any, Callable, Dict, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, Optional, Sequence
import numpy as np
import pytest
......@@ -17,6 +17,13 @@ from torchvision.transforms.functional_tensor import _max_value as get_max_value
__all__ = ["KernelInfo", "KERNEL_INFOS"]
@dataclasses.dataclass
class Skip:
    # Declarative description of a test skip: which test to skip, why, and
    # (optionally) under which condition. By default the skip is unconditional.
    test_name: str
    reason: str
    # Receives the test's ArgsKwargs and the device string; returns True to skip.
    condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True
@dataclasses.dataclass
class KernelInfo:
kernel: Callable
......@@ -36,10 +43,18 @@ class KernelInfo:
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
def __post_init__(self):
    # Fill in field defaults that depend on other fields, then build the
    # test-name -> Skip lookup used by maybe_skip.
    if not self.kernel_name:
        self.kernel_name = self.kernel.__name__
    if self.reference_inputs_fn is None:
        self.reference_inputs_fn = self.sample_inputs_fn
    self._skips_map = dict((rule.test_name, rule) for rule in self.skips)
def maybe_skip(self, *, test_name, args_kwargs, device):
    """Invoke pytest.skip if a Skip registered for this test fires on these inputs."""
    rule = self._skips_map.get(test_name)
    if rule is None:
        return
    if rule.condition(args_kwargs, device):
        pytest.skip(rule.reason)
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
......@@ -1223,3 +1238,267 @@ KERNEL_INFOS.append(
sample_inputs_fn=sample_inputs_erase_image_tensor,
)
)
# Brightness factors exercised by the sample/reference input generators below.
_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor():
    """Yield a minimal set of inputs: one random-size image per color space, single factor."""
    factor = _ADJUST_BRIGHTNESS_FACTORS[0]
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
    )
    for loader in loaders:
        yield ArgsKwargs(loader, brightness_factor=factor)
def reference_inputs_adjust_brightness_image_tensor():
    """Yield the full cross product of no-extra-dims images and all brightness factors."""
    loaders = make_image_loaders(
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
    )
    for loader, factor in itertools.product(loaders, _ADJUST_BRIGHTNESS_FACTORS):
        yield ArgsKwargs(loader, brightness_factor=factor)
# Register the adjust_brightness kernel, using the PIL implementation as reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.adjust_brightness_image_tensor,
        kernel_name="adjust_brightness_image_tensor",
        sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
        reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
        reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
# Contrast factors exercised by the sample/reference input generators below.
_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor():
    """Yield a minimal set of inputs: one random-size image per color space, single factor."""
    factor = _ADJUST_CONTRAST_FACTORS[0]
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
    )
    for loader in loaders:
        yield ArgsKwargs(loader, contrast_factor=factor)
def reference_inputs_adjust_contrast_image_tensor():
    """Yield the full cross product of no-extra-dims images and all contrast factors."""
    loaders = make_image_loaders(
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
    )
    for loader, factor in itertools.product(loaders, _ADJUST_CONTRAST_FACTORS):
        yield ArgsKwargs(loader, contrast_factor=factor)
# Register the adjust_contrast kernel, using the PIL implementation as reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.adjust_contrast_image_tensor,
        kernel_name="adjust_contrast_image_tensor",
        sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
        reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
        reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
# (gamma, gain) pairs exercised by the sample/reference input generators below.
_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]
def sample_inputs_adjust_gamma_image_tensor():
    """Yield one random-size image per color space with a single (gamma, gain) pair."""
    first_gamma, first_gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
    )
    for loader in loaders:
        yield ArgsKwargs(loader, gamma=first_gamma, gain=first_gain)
def reference_inputs_adjust_gamma_image_tensor():
    """Yield the full cross product of no-extra-dims images and all (gamma, gain) pairs."""
    loaders = make_image_loaders(
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
    )
    for loader, (gamma_value, gain_value) in itertools.product(loaders, _ADJUST_GAMMA_GAMMAS_GAINS):
        yield ArgsKwargs(loader, gamma=gamma_value, gain=gain_value)
# Register the adjust_gamma kernel, using the PIL implementation as reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.adjust_gamma_image_tensor,
        kernel_name="adjust_gamma_image_tensor",
        sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
        reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
        reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
# Hue factors exercised by the sample/reference input generators below.
_ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor():
    """Yield a minimal set of inputs: one random-size image per color space, single factor."""
    factor = _ADJUST_HUE_FACTORS[0]
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
    )
    for loader in loaders:
        yield ArgsKwargs(loader, hue_factor=factor)
def reference_inputs_adjust_hue_image_tensor():
    """Yield the full cross product of no-extra-dims images and all hue factors."""
    loaders = make_image_loaders(
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
    )
    for loader, factor in itertools.product(loaders, _ADJUST_HUE_FACTORS):
        yield ArgsKwargs(loader, hue_factor=factor)
# Register the adjust_hue kernel, using the PIL implementation as reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.adjust_hue_image_tensor,
        kernel_name="adjust_hue_image_tensor",
        sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
        reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
        reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
# Saturation factors exercised by the sample/reference input generators below.
_ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor():
    """Yield a minimal set of inputs: one random-size image per color space, single factor."""
    factor = _ADJUST_SATURATION_FACTORS[0]
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
    )
    for loader in loaders:
        yield ArgsKwargs(loader, saturation_factor=factor)
def reference_inputs_adjust_saturation_image_tensor():
    """Yield the full cross product of no-extra-dims images and all saturation factors."""
    loaders = make_image_loaders(
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]
    )
    for loader, factor in itertools.product(loaders, _ADJUST_SATURATION_FACTORS):
        yield ArgsKwargs(loader, saturation_factor=factor)
# Register the adjust_saturation kernel, using the PIL implementation as reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.adjust_saturation_image_tensor,
        kernel_name="adjust_saturation_image_tensor",
        sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
        reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
        reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
    )
)
def sample_inputs_clamp_bounding_box():
    """Yield every bounding box loader together with its own format and image size."""
    for loader in make_bounding_box_loaders():
        yield ArgsKwargs(loader, format=loader.format, image_size=loader.image_size)
# Register clamp_bounding_box. No reference_fn: reference inputs default to the
# sample inputs via KernelInfo.__post_init__.
KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_box,
        sample_inputs_fn=sample_inputs_clamp_bounding_box,
    )
)
# Crop sizes in every shape _get_five_ten_crop_image_size handles: plain int,
# single-element sequences, and (height, width) pairs.
_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
def _get_five_ten_crop_image_size(size):
if isinstance(size, int):
crop_height = crop_width = size
elif len(size) == 1:
crop_height = crop_width = size[0]
else:
crop_height, crop_width = size
return 2 * crop_height, 2 * crop_width
def sample_inputs_five_crop_image_tensor():
    """For each crop size, yield images exactly twice that size."""
    for crop_size in _FIVE_TEN_CROP_SIZES:
        image_size = _get_five_ten_crop_image_size(crop_size)
        for loader in make_image_loaders(sizes=[image_size]):
            yield ArgsKwargs(loader, size=crop_size)
def reference_inputs_five_crop_image_tensor():
    """Like the sample inputs, but restricted to images without extra batch dims."""
    for crop_size in _FIVE_TEN_CROP_SIZES:
        image_size = _get_five_ten_crop_image_size(crop_size)
        for loader in make_image_loaders(sizes=[image_size], extra_dims=[()]):
            yield ArgsKwargs(loader, size=crop_size)
def sample_inputs_ten_crop_image_tensor():
    """Cross product of crop sizes and vertical_flip, with images twice the crop size."""
    for crop_size, flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        image_size = _get_five_ten_crop_image_size(crop_size)
        for loader in make_image_loaders(sizes=[image_size]):
            yield ArgsKwargs(loader, size=crop_size, vertical_flip=flip)
def reference_inputs_ten_crop_image_tensor():
    """Like the sample inputs, but restricted to images without extra batch dims."""
    for crop_size, flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        image_size = _get_five_ten_crop_image_size(crop_size)
        for loader in make_image_loaders(sizes=[image_size], extra_dims=[()]):
            yield ArgsKwargs(loader, size=crop_size, vertical_flip=flip)
# Register the five_crop / ten_crop kernels. Both return a tuple of images rather
# than a single tensor, so tests that assume a tensor output are skipped, and
# scripting with a plain int size is skipped as unsupported.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image_tensor,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            skips=[
                Skip(
                    "test_scripted_vs_eager",
                    condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
                    reason="Integer size is not supported when scripting five_crop_image_tensor.",
                ),
                Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
                Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
                Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
            ],
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
        KernelInfo(
            F.ten_crop_image_tensor,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            skips=[
                Skip(
                    "test_scripted_vs_eager",
                    condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
                    reason="Integer size is not supported when scripting ten_crop_image_tensor.",
                ),
                Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
                Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
                Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
            ],
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
    ]
)
# (mean, std) pairs: the first looks like the commonly used ImageNet statistics,
# the second is a no-op normalization (zero mean, unit std).
_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
]
def sample_inputs_normalize_image_tensor():
    """Yield random-size float32 RGB images combined with each (mean, std) pair."""
    loaders = make_image_loaders(
        sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
    )
    for loader, (mean, std) in itertools.product(loaders, _NORMALIZE_MEANS_STDS):
        yield ArgsKwargs(loader, mean=mean, std=std)
# Register normalize_image_tensor. No reference_fn: reference inputs default to
# the sample inputs via KernelInfo.__post_init__.
KERNEL_INFOS.append(
    KernelInfo(
        F.normalize_image_tensor,
        kernel_name="normalize_image_tensor",
        sample_inputs_fn=sample_inputs_normalize_image_tensor,
    )
)
......@@ -26,6 +26,25 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
@pytest.fixture(autouse=True)
def maybe_skip(request):
    """Auto-applied fixture that delegates skip decisions to the test's info object.

    Does nothing for tests that are not parametrized with both an ``info`` and
    an ``args_kwargs`` parameter.
    """
    try:
        # `callspec` only exists on parametrized tests (AttributeError otherwise),
        # and the params of interest only on info-driven tests (KeyError otherwise).
        params = request.node.callspec.params
        info = params["info"]
        args_kwargs = params["args_kwargs"]
    except (AttributeError, KeyError):
        return
    info.maybe_skip(
        test_name=request.node.originalname,
        args_kwargs=args_kwargs,
        device=params.get("device", "cpu"),
    )
class TestKernels:
sample_inputs = pytest.mark.parametrize(
("info", "args_kwargs"),
......@@ -49,25 +68,6 @@ class TestKernels:
assert_close(actual, expected, **info.closeness_kwargs)
# TODO: We need this until the kernels below also have `KernelInfo`'s. If they do, `test_scripted_vs_eager` replaces
# this test for them.
@pytest.mark.parametrize(
    "kernel",
    [
        F.adjust_brightness_image_tensor,
        F.adjust_gamma_image_tensor,
        F.adjust_hue_image_tensor,
        F.adjust_saturation_image_tensor,
        F.clamp_bounding_box,
        F.five_crop_image_tensor,
        F.normalize_image_tensor,
        F.ten_crop_image_tensor,
    ],
    ids=lambda kernel: kernel.__name__,
)
def test_scriptable(self, kernel):
    # Smoke test: only checks that the kernel can be torch.jit.script'ed,
    # not that the scripted version behaves like eager.
    script(kernel)
def _unbind_batch_dims(self, batched_tensor, *, data_dims):
if batched_tensor.ndim == data_dims:
return batched_tensor
......@@ -190,22 +190,13 @@ class TestDispatchers:
@pytest.mark.parametrize(
"dispatcher",
[
F.adjust_brightness,
F.adjust_contrast,
F.adjust_gamma,
F.adjust_hue,
F.adjust_saturation,
F.convert_color_space,
F.convert_image_dtype,
F.elastic_transform,
F.five_crop,
F.get_dimensions,
F.get_image_num_channels,
F.get_image_size,
F.get_spatial_size,
F.normalize,
F.rgb_to_grayscale,
F.ten_crop,
],
ids=lambda dispatcher: dispatcher.__name__,
)
......@@ -222,6 +213,7 @@ class TestDispatchers:
(F.vflip, F.vertical_flip),
(F.get_image_num_channels, F.get_num_channels),
(F.to_pil_image, F.to_image_pil),
(F.elastic_transform, F.elastic),
]
],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment