Unverified Commit 658ca539 authored by Philip Meier, committed by GitHub

cleanup prototype transforms functional tests (#6622)

* cleanup prototype transforms functional tests

* fix

* oust local functions
parent f49edd3b
@@ -205,12 +205,8 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):

 def cache(fn):
-    """Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite buffer size,
-    but also caches exceptions.
-
-    .. warning::
-
-        Only use this on deterministic functions.
+    """Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
+    but this also caches exceptions.
     """
     sentinel = object()
     out_cache = {}
@@ -238,11 +234,3 @@ def cache(fn):
         return out

     return wrapper
-
-
-@cache
-def script(fn):
-    try:
-        return torch.jit.script(fn)
-    except Exception as error:
-        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
 import dataclasses
-import functools
 from typing import Callable, Dict, Type

 import pytest
-import torch
 import torchvision.prototype.transforms.functional as F
-from prototype_common_utils import ArgsKwargs
 from prototype_transforms_kernel_infos import KERNEL_INFOS
-from test_prototype_transforms_functional import FUNCTIONAL_INFOS
 from torchvision.prototype import features

 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]

@@ -15,26 +11,6 @@ __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
 KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}

-
-# Helper class to use the infos from the old framework for the new tests
-class PreloadedArgsKwargs(ArgsKwargs):
-    def load(self, device="cpu"):
-        args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in self.args)
-        kwargs = {
-            keyword: arg.to(device) if isinstance(arg, torch.Tensor) else arg for keyword, arg in self.kwargs.items()
-        }
-        return args, kwargs
-
-
-def preloaded_sample_inputs(args_kwargs):
-    for args, kwargs in args_kwargs:
-        yield PreloadedArgsKwargs(*args, **kwargs)
-
-
-KERNEL_SAMPLE_INPUTS_FN_MAP.update(
-    {info.functional: functools.partial(preloaded_sample_inputs, info.sample_inputs()) for info in FUNCTIONAL_INFOS}
-)
-
-
 @dataclasses.dataclass
 class DispatcherInfo:
     dispatcher: Callable
...
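Based on how `DISPATCHER_INFOS` is consumed in the tests (`info.kernels`, `info.sample_inputs(features.Image)`), a `DispatcherInfo` plausibly resolves sample inputs by looking up the per-feature kernel in `KERNEL_SAMPLE_INPUTS_FN_MAP`. A hedged sketch, not self-contained (it reuses `KERNEL_SAMPLE_INPUTS_FN_MAP` from the file above, and everything beyond the `dispatcher` field is an assumption):

```python
import dataclasses
from typing import Callable, Dict, Type


@dataclasses.dataclass
class DispatcherInfo:
    dispatcher: Callable
    kernels: Dict[Type, Callable]  # feature type -> kernel, e.g. {features.Image: resize_image_tensor}

    def sample_inputs(self, *feature_types):
        # Reuse the sample inputs of the kernels the dispatcher dispatches to.
        for feature_type in feature_types or self.kernels.keys():
            yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[feature_type]]()
```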
-import itertools
-
-import pytest
-import torch.jit
-from common_utils import cpu_and_gpu, script
-from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
-from torchvision.prototype import features
-
-
-class TestCommon:
-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
-            for info in DISPATCHER_INFOS
-            # FIXME: This is a hack to avoid undiagnosed memory issues in CI right now. The current working guess is
-            # that we run out of memory, because too many tensors are instantiated upfront. This should be solved by
-            # the loader architecture.
-            for idx, args_kwargs in itertools.islice(enumerate(info.sample_inputs(features.Image)), 10)
-            if features.Image in info.kernels
-        ],
-    )
-    @pytest.mark.parametrize("device", cpu_and_gpu())
-    def test_scripted_smoke(self, info, args_kwargs, device):
-        fn = script(info.dispatcher)
-
-        (image_feature, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_feature)
-        fn(image_simple_tensor, *other_args, **kwargs)
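The FIXME above caps parametrization at 10 samples per dispatcher. `itertools.islice` does that lazily, without materializing the underlying generator. In miniature:

```python
import itertools


def naturals():
    n = 0
    while True:
        yield n
        n += 1


# Takes at most 10 items; the infinite generator is never exhausted.
first_ten = list(itertools.islice(enumerate(naturals()), 10))
assert len(first_ten) == 10
```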
@@ -4,188 +4,233 @@ import os
 import numpy as np
 import PIL.Image
 import pytest
-import torch.testing
-import torchvision.prototype.transforms.functional as F
-from common_utils import cpu_and_gpu
-from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image
-from torch import jit
+import torch
+from common_utils import cache, cpu_and_gpu, needs_cuda
+from prototype_common_utils import assert_close, make_bounding_boxes, make_image
+from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
+from prototype_transforms_kernel_infos import KERNEL_INFOS
+from torch.utils._pytree import tree_map
 from torchvision.prototype import features
+from torchvision.prototype.transforms import functional as F
 from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
 from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
 from torchvision.transforms.functional import _get_perspective_coeffs
-class FunctionalInfo:
-    def __init__(self, name, *, sample_inputs_fn):
-        self.name = name
-        self.functional = getattr(F, name)
-        self._sample_inputs_fn = sample_inputs_fn
-
-    def sample_inputs(self):
-        yield from self._sample_inputs_fn()
-
-    def __call__(self, *args, **kwargs):
-        if len(args) == 1 and not kwargs and isinstance(args[0], ArgsKwargs):
-            sample_input = args[0]
-            return self.functional(*sample_input.args, **sample_input.kwargs)
-        return self.functional(*args, **kwargs)
-
-
-FUNCTIONAL_INFOS = []
+@cache
+def script(fn):
+    try:
+        return torch.jit.script(fn)
+    except Exception as error:
+        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
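Because `script` is decorated with `@cache`, each kernel is compiled at most once per test session, and a kernel that cannot be scripted re-raises the same `AssertionError` on every later use. Hypothetical usage, assuming `F.resize_image_tensor` is scriptable:

```python
scripted = script(F.resize_image_tensor)          # compiles once
assert script(F.resize_image_tensor) is scripted  # later calls are cache hits
```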
-def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
-    FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
-    return sample_inputs_fn
-
-
-_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}
-
-
-def _distinct_callables(callable_names):
-    # Ensure we deduplicate callables (due to aliases) without losing the names on the new API
-    remove = set()
-    distinct = set()
-    for name in callable_names:
-        item = F.__dict__[name]
-        if item not in distinct:
-            distinct.add(item)
-        else:
-            remove.add(name)
-    callable_names -= remove
-    # create tuple and sort by name
-    return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])
-
-
-def _get_distinct_kernels():
-    kernel_names = {
-        name
-        for name, f in F.__dict__.items()
-        if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
-    }
-    return _distinct_callables(kernel_names)
-
-
-def _get_distinct_midlevels():
-    midlevel_names = {
-        name
-        for name, f in F.__dict__.items()
-        if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
-    }
-    return _distinct_callables(midlevel_names)
+class TestKernels:
+    sample_inputs = pytest.mark.parametrize(
+        ("info", "args_kwargs"),
+        [
+            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
+            for info in KERNEL_INFOS
+            for idx, args_kwargs in enumerate(info.sample_inputs_fn())
+        ],
+    )
+
+    @sample_inputs
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_scripted_vs_eager(self, info, args_kwargs, device):
+        kernel_eager = info.kernel
+        kernel_scripted = script(kernel_eager)
+
+        args, kwargs = args_kwargs.load(device)
+
+        actual = kernel_scripted(*args, **kwargs)
+        expected = kernel_eager(*args, **kwargs)
+
+        assert_close(actual, expected, **info.closeness_kwargs)
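Assigning `pytest.mark.parametrize(...)` to a class attribute and reusing it as a decorator, as `TestKernels.sample_inputs` does above, is plain pytest: the mark object is itself a decorator, so one parametrization can be shared across many tests. A minimal self-contained sketch:

```python
import pytest


class TestArithmetic:
    # One shared parametrization, applied to several tests below.
    values = pytest.mark.parametrize("value", [0, 1, 2])

    @values
    def test_double_halves_back(self, value):
        assert (value * 2) / 2 == value

    @values
    def test_square_non_negative(self, value):
        assert value**2 >= 0
```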
-@pytest.mark.parametrize(
-    "kernel",
-    [
-        pytest.param(kernel, id=name)
-        for name, kernel in _get_distinct_kernels()
-        if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
-    ],
-)
-def test_scriptable_kernel(kernel):
-    jit.script(kernel)  # TODO: pass data through it
-
-
-@pytest.mark.parametrize(
-    "midlevel",
-    [
-        pytest.param(midlevel, id=name)
-        for name, midlevel in _get_distinct_midlevels()
-        if name
-        not in {
-            "InterpolationMode",
-            "decode_image_with_pil",
-            "decode_video_with_av",
-            "pil_to_tensor",
-            "to_grayscale",
-            "to_pil_image",
-            "to_tensor",
-        }
-    ],
-)
-def test_scriptable_midlevel(midlevel):
-    jit.script(midlevel)  # TODO: pass data through it
+    # TODO: We need this until the kernels below also have `KernelInfo`'s. If they do, `test_scripted_vs_eager`
+    # replaces this test for them.
+    @pytest.mark.parametrize(
+        "kernel",
+        [
+            F.adjust_brightness_image_tensor,
+            F.adjust_gamma_image_tensor,
+            F.adjust_hue_image_tensor,
+            F.adjust_saturation_image_tensor,
+            F.clamp_bounding_box,
+            F.five_crop_image_tensor,
+            F.normalize_image_tensor,
+            F.ten_crop_image_tensor,
+        ],
+        ids=lambda kernel: kernel.__name__,
+    )
+    def test_scriptable(self, kernel):
+        script(kernel)
+
+    def _unbind_batch_dims(self, batched_tensor, *, data_dims):
+        if batched_tensor.ndim == data_dims:
+            return batched_tensor
+        return [self._unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
+
+    def _stack_batch_dims(self, unbound_tensor):
+        if isinstance(unbound_tensor[0], torch.Tensor):
+            return torch.stack(unbound_tensor)
+        return torch.stack([self._stack_batch_dims(t) for t in unbound_tensor])
+
+    @sample_inputs
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_batched_vs_single(self, info, args_kwargs, device):
+        (batched_input, *other_args), kwargs = args_kwargs.load(device)
+
+        feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
+        # This dictionary contains the number of rightmost dimensions that contain the actual data.
+        # Everything to the left is considered a batch dimension.
+        data_dims = {
+            features.Image: 3,
+            features.BoundingBox: 1,
+            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
+            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
+            # under one type, all kernels should also work without differentiating between the two. Thus, we go
+            # with 2 here as common ground.
+            features.Mask: 2,
+        }.get(feature_type)
+        if data_dims is None:
+            raise pytest.UsageError(
+                f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
+            ) from None
+        elif batched_input.ndim <= data_dims:
+            pytest.skip("Input is not batched.")
+        elif not all(batched_input.shape[:-data_dims]):
+            pytest.skip("Input has a degenerate batch shape.")
+
+        actual = info.kernel(batched_input, *other_args, **kwargs)
+
+        single_inputs = self._unbind_batch_dims(batched_input, data_dims=data_dims)
+        single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
+        expected = self._stack_batch_dims(single_outputs)
+
+        assert_close(actual, expected, **info.closeness_kwargs)
+
+    @sample_inputs
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_no_inplace(self, info, args_kwargs, device):
+        (input, *other_args), kwargs = args_kwargs.load(device)
+
+        if input.numel() == 0:
+            pytest.skip("The input has a degenerate shape.")
+
+        input_version = input._version
+        output = info.kernel(input, *other_args, **kwargs)
+
+        assert output is not input or output._version == input_version
+
+    @sample_inputs
+    @needs_cuda
+    def test_cuda_vs_cpu(self, info, args_kwargs):
+        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
+        input_cuda = input_cpu.to("cuda")
+
+        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
+        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)
+
+        assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs)
+
+    @sample_inputs
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
+        (input, *other_args), kwargs = args_kwargs.load(device)
+
+        output = info.kernel(input, *other_args, **kwargs)
+
+        assert output.dtype == input.dtype
+        assert output.device == input.device
+
+    @pytest.mark.parametrize(
+        ("info", "args_kwargs"),
+        [
+            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
+            for info in KERNEL_INFOS
+            for idx, args_kwargs in enumerate(info.reference_inputs_fn())
+            if info.reference_fn is not None
+        ],
+    )
+    def test_against_reference(self, info, args_kwargs):
+        args, kwargs = args_kwargs.load("cpu")
+
+        actual = info.kernel(*args, **kwargs)
+        expected = info.reference_fn(*args, **kwargs)
+
+        assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs)
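`test_no_inplace` above relies on `torch.Tensor._version`, an internal counter that PyTorch bumps on every in-place mutation; if a kernel returns its input object, the version check catches hidden in-place edits. An illustration (uses a private attribute, so treat as a sketch):

```python
import torch

t = torch.zeros(3)
version = t._version

u = t.add(1)  # out-of-place: produces a new tensor, input version unchanged
assert t._version == version

t.add_(1)  # in-place: the version counter is bumped
assert t._version > version
```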
-# Test below is intended to test mid-level op vs low-level ops it calls
-# For example, resize -> resize_image_tensor, resize_bounding_boxes etc
-# TODO: Rewrite this test as sample args may include more or fewer params
-# than needed by functions
-@pytest.mark.parametrize(
-    "func",
-    [
-        pytest.param(func, id=name)
-        for name, func in F.__dict__.items()
-        if not name.startswith("_") and callable(func)
-        # TODO: remove aliases
-        and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
-        and name
-        not in {
-            "to_image_tensor",
-            "InterpolationMode",
-            "decode_video_with_av",
-            "crop",
-            "perspective",
-            "elastic_transform",
-            "elastic",
-        }
-        # We skip 'crop' due to missing 'height' and 'width'
-        # We skip 'perspective' as it requires different input args than perspective_image_tensor etc
-        # Skip 'elastic', TODO: inspect why test is failing
-    ],
-)
-def test_functional_mid_level(func):
-    finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name]
-    for finfo in finfos:
-        for sample_input in finfo.sample_inputs():
-            expected = finfo(sample_input)
-            kwargs = dict(sample_input.kwargs)
-            for key in ["format", "image_size"]:
-                if key in kwargs:
-                    del kwargs[key]
-            output = func(*sample_input.args, **kwargs)
-            torch.testing.assert_close(
-                output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}"
-            )
-            break
-
-
-@pytest.mark.parametrize(
-    ("functional_info", "sample_input"),
-    [
-        pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
-        for functional_info in FUNCTIONAL_INFOS
-        for idx, sample_input in enumerate(functional_info.sample_inputs())
-    ],
-)
-def test_eager_vs_scripted(functional_info, sample_input):
-    eager = functional_info(sample_input)
-    scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
-    torch.testing.assert_close(eager, scripted)
+class TestDispatchers:
+    @pytest.mark.parametrize(
+        ("info", "args_kwargs"),
+        [
+            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
+            for info in DISPATCHER_INFOS
+            for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
+            if features.Image in info.kernels
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_scripted_smoke(self, info, args_kwargs, device):
+        dispatcher = script(info.dispatcher)
+
+        (image_feature, *other_args), kwargs = args_kwargs.load(device)
+        image_simple_tensor = torch.Tensor(image_feature)
+
+        dispatcher(image_simple_tensor, *other_args, **kwargs)
+
+    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
+    # replaces this test for them.
+    @pytest.mark.parametrize(
+        "dispatcher",
+        [
+            F.adjust_brightness,
+            F.adjust_contrast,
+            F.adjust_gamma,
+            F.adjust_hue,
+            F.adjust_saturation,
+            F.convert_color_space,
+            F.convert_image_dtype,
+            F.elastic_transform,
+            F.five_crop,
+            F.get_dimensions,
+            F.get_image_num_channels,
+            F.get_image_size,
+            F.get_spatial_size,
+            F.normalize,
+            F.rgb_to_grayscale,
+            F.ten_crop,
+        ],
+        ids=lambda dispatcher: dispatcher.__name__,
+    )
+    def test_scriptable(self, dispatcher):
+        script(dispatcher)
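`test_scripted_smoke` feeds the scripted dispatcher a plain tensor (`torch.Tensor(image_feature)`) rather than the `features.Image` subclass, since TorchScript dispatches on `torch.Tensor` and does not know the prototype feature classes. The same unwrapping in isolation (a sketch; the exact copy/share semantics of the legacy `torch.Tensor(...)` constructor are not relied on here):

```python
import torch


class MySubclassTensor(torch.Tensor):
    pass


wrapped = torch.zeros(3, 8, 8).as_subclass(MySubclassTensor)
plain = torch.Tensor(wrapped)  # same values, but a plain torch.Tensor again

assert type(plain) is torch.Tensor
assert torch.equal(plain, wrapped)
```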
-@pytest.mark.parametrize(
-    ("functional_info", "sample_input"),
-    [
-        pytest.param(
-            functional_info,
-            sample_input,
-            id=f"{functional_info.name}-{idx}",
-        )
-        for functional_info in FUNCTIONAL_INFOS
-        for idx, sample_input in enumerate(functional_info.sample_inputs())
-    ],
-)
-def test_dtype_consistency(functional_info, sample_input):
-    (input, *other_args), kwargs = sample_input
-    output = functional_info.functional(input, *other_args, **kwargs)
-    assert output.dtype == input.dtype
+@pytest.mark.parametrize(
+    ("alias", "target"),
+    [
+        pytest.param(alias, target, id=alias.__name__)
+        for alias, target in [
+            (F.hflip, F.horizontal_flip),
+            (F.vflip, F.vertical_flip),
+            (F.get_image_num_channels, F.get_num_channels),
+            (F.to_pil_image, F.to_image_pil),
+        ]
+    ],
+)
+def test_alias(alias, target):
+    assert alias is target
+
+
+# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
+# `prototype_transforms_kernel_infos.py`

 def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
...
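`test_alias` works because the aliases are plain name bindings in the functional namespace, e.g. `hflip = horizontal_flip`; `is` then compares object identity, which is stricter than `==`. In miniature:

```python
def horizontal_flip():
    ...


hflip = horizontal_flip  # alias: the same function object under a second name

assert hflip is horizontal_flip
```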
-import pytest
-import torch.testing
-from common_utils import cpu_and_gpu, needs_cuda, script
-from prototype_common_utils import assert_close
-from prototype_transforms_kernel_infos import KERNEL_INFOS
-from torch.utils._pytree import tree_map
-from torchvision._utils import sequence_to_str
-from torchvision.prototype import features
-from torchvision.prototype.transforms import functional as F
-
-
-def test_coverage():
-    tested = {info.kernel_name for info in KERNEL_INFOS}
-    exposed = {
-        name
-        for name, kernel in F.__dict__.items()
-        if callable(kernel)
-        and any(
-            name.endswith(f"_{feature_name}")
-            for feature_name in {
-                "bounding_box",
-                "image_tensor",
-                "label",
-                "mask",
-            }
-        )
-        and name not in {"to_image_tensor"}
-        # TODO: The list below should be quickly reduced in the transition period. There is nothing that prevents us
-        # from adding `KernelInfo`'s for these kernels other than time.
-        and name
-        not in {
-            "adjust_brightness_image_tensor",
-            "adjust_contrast_image_tensor",
-            "adjust_gamma_image_tensor",
-            "adjust_hue_image_tensor",
-            "adjust_saturation_image_tensor",
-            "clamp_bounding_box",
-            "five_crop_image_tensor",
-            "normalize_image_tensor",
-            "ten_crop_image_tensor",
-        }
-    }
-
-    needlessly_ignored = tested - exposed
-    if needlessly_ignored:
-        raise pytest.UsageError(
-            f"The kernel(s) {sequence_to_str(sorted(needlessly_ignored), separate_last='and ')} "
-            f"have an associated `KernelInfo` but are ignored by this test."
-        )
-
-    untested = exposed - tested
-    if untested:
-        raise AssertionError(
-            f"The kernel(s) {sequence_to_str(sorted(untested), separate_last='and ')} "
-            f"are exposed through `torchvision.prototype.transforms.functional`, but are not tested. "
-            f"Please add a `KernelInfo` to the `KERNEL_INFOS` list in `test/prototype_transforms_kernel_infos.py`."
-        )
-
-
-class TestCommon:
-    sample_inputs = pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
-            for info in KERNEL_INFOS
-            for idx, args_kwargs in enumerate(info.sample_inputs_fn())
-        ],
-    )
-
-    @sample_inputs
-    @pytest.mark.parametrize("device", cpu_and_gpu())
-    def test_scripted_vs_eager(self, info, args_kwargs, device):
-        kernel_eager = info.kernel
-        kernel_scripted = script(kernel_eager)
-
-        args, kwargs = args_kwargs.load(device)
-
-        actual = kernel_scripted(*args, **kwargs)
-        expected = kernel_eager(*args, **kwargs)
-
-        assert_close(actual, expected, **info.closeness_kwargs)
-
-    @sample_inputs
-    @pytest.mark.parametrize("device", cpu_and_gpu())
-    def test_batched_vs_single(self, info, args_kwargs, device):
-        def unbind_batch_dims(batched_tensor, *, data_dims):
-            if batched_tensor.ndim == data_dims:
-                return batched_tensor
-            return [unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
-
-        def stack_batch_dims(unbound_tensor):
-            if isinstance(unbound_tensor[0], torch.Tensor):
-                return torch.stack(unbound_tensor)
-            return torch.stack([stack_batch_dims(t) for t in unbound_tensor])
-
-        (batched_input, *other_args), kwargs = args_kwargs.load(device)
-
-        feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
-        # This dictionary contains the number of rightmost dimensions that contain the actual data.
-        # Everything to the left is considered a batch dimension.
-        data_dims = {
-            features.Image: 3,
-            features.BoundingBox: 1,
-            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
-            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
-            # under one type, all kernels should also work without differentiating between the two. Thus, we go
-            # with 2 here as common ground.
-            features.Mask: 2,
-        }.get(feature_type)
-        if data_dims is None:
-            raise pytest.UsageError(
-                f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
-            ) from None
-        elif batched_input.ndim <= data_dims:
-            pytest.skip("Input is not batched.")
-        elif not all(batched_input.shape[:-data_dims]):
-            pytest.skip("Input has a degenerate batch shape.")
-
-        actual = info.kernel(batched_input, *other_args, **kwargs)
-
-        single_inputs = unbind_batch_dims(batched_input, data_dims=data_dims)
-        single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
-        expected = stack_batch_dims(single_outputs)
-
-        assert_close(actual, expected, **info.closeness_kwargs)
-
-    @sample_inputs
-    @pytest.mark.parametrize("device", cpu_and_gpu())
-    def test_no_inplace(self, info, args_kwargs, device):
-        (input, *other_args), kwargs = args_kwargs.load(device)
-
-        if input.numel() == 0:
-            pytest.skip("The input has a degenerate shape.")
-
-        input_version = input._version
-        output = info.kernel(input, *other_args, **kwargs)
-
-        assert output is not input or output._version == input_version
-
-    @sample_inputs
-    @needs_cuda
-    def test_cuda_vs_cpu(self, info, args_kwargs):
-        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
-        input_cuda = input_cpu.to("cuda")
-
-        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
-        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)
-
-        assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs)
-
-    @sample_inputs
-    @pytest.mark.parametrize("device", cpu_and_gpu())
-    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
-        (input, *other_args), kwargs = args_kwargs.load(device)
-
-        output = info.kernel(input, *other_args, **kwargs)
-
-        assert output.dtype == input.dtype
-        assert output.device == input.device
-
-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
-            for info in KERNEL_INFOS
-            for idx, args_kwargs in enumerate(info.reference_inputs_fn())
-            if info.reference_fn is not None
-        ],
-    )
-    def test_against_reference(self, info, args_kwargs):
-        args, kwargs = args_kwargs.load("cpu")
-
-        actual = info.kernel(*args, **kwargs)
-        expected = info.reference_fn(*args, **kwargs)
-
-        assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs)
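The deleted `test_coverage` above boils down to two set differences between the kernels that are exposed in the functional namespace and the kernels that have a `KernelInfo`. The shape of that check, reduced to its core (kernel names are illustrative):

```python
exposed = {"resize_image_tensor", "resize_bounding_box", "rotate_image_tensor"}
tested = {"resize_image_tensor", "resize_bounding_box"}

untested = exposed - tested            # exposed kernels without a KernelInfo
needlessly_ignored = tested - exposed  # KernelInfos for kernels the test ignores

assert not needlessly_ignored
assert untested == {"rotate_image_tensor"}
```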