Unverified Commit d2d448c7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for the output types of prototype functional dispatchers (#7118)

parent 01d138d8
......@@ -112,6 +112,15 @@ skip_dispatch_datapoint = TestMark(
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
)
multi_crop_skips = [
TestMark(
("TestDispatchers", test_name),
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
)
for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
]
multi_crop_skips.append(skip_dispatch_datapoint)
def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
......@@ -404,7 +413,7 @@ DISPATCHER_INFOS = [
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_datapoint,
*multi_crop_skips,
],
),
DispatcherInfo(
......@@ -415,7 +424,7 @@ DISPATCHER_INFOS = [
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_datapoint,
*multi_crop_skips,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
),
......
......@@ -362,6 +362,16 @@ class TestDispatchers:
spy.assert_called_once()
@image_sample_inputs
def test_simple_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
......@@ -381,6 +391,22 @@ class TestDispatchers:
spy.assert_called_once()
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)
def test_pil_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
if image_datapoint.ndim > 3:
pytest.skip("Input is batched")
image_pil = F.to_image_pil(image_datapoint)
output = info.dispatcher(image_pil, *other_args, **kwargs)
assert isinstance(output, PIL.Image.Image)
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
......@@ -397,6 +423,17 @@ class TestDispatchers:
spy.assert_called_once()
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_datapoint_output_type(self, info, args_kwargs):
(datapoint, *other_args), kwargs = args_kwargs.load()
output = info.dispatcher(datapoint, *other_args, **kwargs)
assert isinstance(output, type(datapoint))
@pytest.mark.parametrize(
("dispatcher_info", "datapoint_type", "kernel_info"),
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment