Unverified Commit 6ebbdfe8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add JIT script tests for prototype functional dispatchers (#6606)



* add xfailed smoke tests for dispatchers

* also support old FunctionalInfo's

* try reduce memory consumption for CI

* fix sample inputs generation
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent dbbc5c8e
...@@ -202,3 +202,47 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs): ...@@ -202,3 +202,47 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
# scriptable function test # scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs) s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol) torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite buffer size,
    but also caches exceptions.

    .. warning::

        Only use this on deterministic functions.
    """
    sentinel = object()
    out_cache = {}
    exc_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Include the keyword NAMES in the key: using only `kwargs.values()` would make calls
        # like `fn(a=1)` and `fn(b=1)` collide on the same key. Sorting the items makes the key
        # independent of the keyword order at the call site.
        key = args + tuple(sorted(kwargs.items()))

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc = exc_cache.get(key, sentinel)
        if exc is not sentinel:
            # Re-raise the exact exception object the first call produced.
            raise exc

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            exc_cache[key] = exc
            raise exc

        out_cache[key] = out
        return out

    return wrapper
@cache
def script(fn):
    """`torch.jit.script` *fn* once per function, converting scripting failures into assertion errors."""
    try:
        scripted = torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
    return scripted
import dataclasses
import functools
from typing import Callable, Dict, Type
import pytest
import torch
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import ArgsKwargs
from prototype_transforms_kernel_infos import KERNEL_INFOS
from test_prototype_transforms_functional import FUNCTIONAL_INFOS
from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
# Map each kernel to its KernelInfo's sample inputs function, so sample inputs can be looked up
# through the kernel a dispatcher dispatches to.
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
# Helper class to reuse the infos from the old test framework in the new tests
class PreloadedArgsKwargs(ArgsKwargs):
    """ArgsKwargs variant whose sample inputs are already instantiated tensors rather than lazy loaders."""

    def load(self, device="cpu"):
        """Return ``(args, kwargs)`` with every tensor moved to *device*; non-tensor values pass through."""

        def move(value):
            return value.to(device) if isinstance(value, torch.Tensor) else value

        loaded_args = tuple(move(arg) for arg in self.args)
        loaded_kwargs = {name: move(value) for name, value in self.kwargs.items()}
        return loaded_args, loaded_kwargs
def preloaded_sample_inputs(args_kwargs):
    """Wrap already-instantiated ``(args, kwargs)`` pairs into ``PreloadedArgsKwargs`` samples."""
    for sample_args, sample_kwargs in args_kwargs:
        yield PreloadedArgsKwargs(*sample_args, **sample_kwargs)
# Also support infos from the old framework: their sample inputs are generated eagerly here and
# wrapped so they can be looked up through the same map as the new-style kernel infos.
KERNEL_SAMPLE_INPUTS_FN_MAP.update(
    {info.functional: functools.partial(preloaded_sample_inputs, info.sample_inputs()) for info in FUNCTIONAL_INFOS}
)
@dataclasses.dataclass
class DispatcherInfo:
    """Metadata for a prototype dispatcher and the per-feature-type kernels it dispatches to.

    Attributes:
        dispatcher: The dispatcher under test, e.g. ``F.resize``.
        kernels: Mapping from feature type (e.g. ``features.Image``) to the kernel the
            dispatcher invokes for that type.
    """

    dispatcher: Callable
    kernels: Dict[Type, Callable]

    def sample_inputs(self, *types):
        """Yield sample inputs for the given feature types, or for all registered types if none are given.

        Raises:
            pytest.UsageError: If a requested type has no kernel registered in ``self.kernels``.
        """
        # Name the loop variable `feature_type` rather than `type` to avoid shadowing the builtin.
        for feature_type in types or self.kernels.keys():
            if feature_type not in self.kernels:
                raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")

            yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[feature_type]]()
def _info_for_all_features(name):
    """DispatcherInfo for a dispatcher with image, bounding box, and mask kernels named `{name}_*`."""
    return DispatcherInfo(
        getattr(F, name),
        kernels={
            features.Image: getattr(F, f"{name}_image_tensor"),
            features.BoundingBox: getattr(F, f"{name}_bounding_box"),
            features.Mask: getattr(F, f"{name}_mask"),
        },
    )


def _info_for_images_only(name):
    """DispatcherInfo for a dispatcher that only has an image kernel named `{name}_image_tensor`."""
    return DispatcherInfo(
        getattr(F, name),
        kernels={
            features.Image: getattr(F, f"{name}_image_tensor"),
        },
    )


DISPATCHER_INFOS = [
    *[
        _info_for_all_features(name)
        for name in (
            "horizontal_flip",
            "resize",
            "affine",
            "vertical_flip",
            "rotate",
            "crop",
            "resized_crop",
            "pad",
            "perspective",
            "center_crop",
        )
    ],
    *[
        _info_for_images_only(name)
        for name in (
            "gaussian_blur",
            "equalize",
            "invert",
            "posterize",
            "solarize",
            "autocontrast",
            "adjust_sharpness",
            "erase",
        )
    ],
]
import itertools
import pytest
import torch.jit
from common_utils import cpu_and_gpu, script
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from torchvision.prototype import features
class TestCommon:
    @pytest.mark.xfail(reason="dispatchers are currently not scriptable")
    @pytest.mark.parametrize(
        ("info", "args_kwargs"),
        [
            pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
            for info in DISPATCHER_INFOS
            # The guard has to come before the inner loop: `sample_inputs()` raises
            # `pytest.UsageError` for a type without a registered kernel, so filtering after the
            # samples are drawn would error out during collection instead of skipping the info.
            if features.Image in info.kernels
            # FIXME: This is a hack to avoid undiagnosed memory issues in CI right now. The current working guess is
            # that we run out of memory, because too many tensors are instantiated upfront. This should be solved by
            # the loader architecture.
            for idx, args_kwargs in itertools.islice(enumerate(info.sample_inputs(features.Image)), 10)
        ],
    )
    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_scripted_smoke(self, info, args_kwargs, device):
        """Smoke test: the dispatcher can be scripted and the scripted version runs on a plain tensor."""
        fn = script(info.dispatcher)

        (image_feature, *other_args), kwargs = args_kwargs.load(device)
        # Strip the feature subclass: scripted dispatchers are exercised with a simple tensor image.
        image_simple_tensor = torch.Tensor(image_feature)

        fn(image_simple_tensor, *other_args, **kwargs)
import pytest import pytest
import torch.testing import torch.testing
from common_utils import cpu_and_gpu, needs_cuda from common_utils import cpu_and_gpu, needs_cuda, script
from prototype_common_utils import assert_close from prototype_common_utils import assert_close
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -104,10 +104,7 @@ class TestCommon: ...@@ -104,10 +104,7 @@ class TestCommon:
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device): def test_scripted_vs_eager(self, info, args_kwargs, device):
kernel_eager = info.kernel kernel_eager = info.kernel
try: kernel_scripted = script(kernel_eager)
kernel_scripted = torch.jit.script(kernel_eager)
except Exception as error:
raise AssertionError("Trying to `torch.jit.script` the kernel raised the error above.") from error
args, kwargs = args_kwargs.load(device) args, kwargs = args_kwargs.load(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment