Unverified Commit 60449a44 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype functional kernel scriptability tests (#5496)

* expand prototype functional scriptability tests

* remove obsolete skips
parent 95d41897
...@@ -199,21 +199,31 @@ def resize_bounding_box(): ...@@ -199,21 +199,31 @@ def resize_bounding_box():
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
class TestKernelsCommon: @pytest.mark.parametrize(
@pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name) "kernel",
def test_scriptable(self, functional_info): [
jit.script(functional_info.functional) pytest.param(kernel, id=name)
for name, kernel in F.__dict__.items()
@pytest.mark.parametrize( if not name.startswith("_")
("functional_info", "sample_input"), and callable(kernel)
[ and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") and "pil" not in name
for functional_info in FUNCTIONAL_INFOS ],
for idx, sample_input in enumerate(functional_info.sample_inputs()) )
], def test_scriptable(kernel):
) jit.script(kernel)
def test_eager_vs_scripted(self, functional_info, sample_input):
eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) @pytest.mark.parametrize(
("functional_info", "sample_input"),
torch.testing.assert_close(eager, scripted) [
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
for functional_info in FUNCTIONAL_INFOS
for idx, sample_input in enumerate(functional_info.sample_inputs())
],
)
def test_eager_vs_scripted(functional_info, sample_input):
eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
torch.testing.assert_close(eager, scripted)
import unittest.mock import unittest.mock
from typing import Dict, Any, Tuple, cast from typing import Dict, Any, Tuple
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor ...@@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return cast(torch.Tensor, one_hot(label, num_classes=num_categories)) return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment