Unverified Commit 60449a44 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype functional kernel scriptability tests (#5496)

* expand prototype functional scriptability tests

* remove obsolete skips
parent 95d41897
...@@ -199,20 +199,30 @@ def resize_bounding_box(): ...@@ -199,20 +199,30 @@ def resize_bounding_box():
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
class TestKernelsCommon: @pytest.mark.parametrize(
@pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name) "kernel",
def test_scriptable(self, functional_info): [
jit.script(functional_info.functional) pytest.param(kernel, id=name)
for name, kernel in F.__dict__.items()
if not name.startswith("_")
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
],
)
def test_scriptable(kernel):
jit.script(kernel)
@pytest.mark.parametrize(
@pytest.mark.parametrize(
("functional_info", "sample_input"), ("functional_info", "sample_input"),
[ [
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
for functional_info in FUNCTIONAL_INFOS for functional_info in FUNCTIONAL_INFOS
for idx, sample_input in enumerate(functional_info.sample_inputs()) for idx, sample_input in enumerate(functional_info.sample_inputs())
], ],
) )
def test_eager_vs_scripted(self, functional_info, sample_input): def test_eager_vs_scripted(functional_info, sample_input):
eager = functional_info(sample_input) eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
......
import unittest.mock import unittest.mock
from typing import Dict, Any, Tuple, cast from typing import Dict, Any, Tuple
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor ...@@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return cast(torch.Tensor, one_hot(label, num_classes=num_categories)) return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment