Unverified Commit e3f7baaf authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add test for dispatcher kernel signature consistency (#6904)



* add test for dispatcher kernel signature consistency

* add dispatcher feature signature consistency test

* fix error message
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4f3a000b
import inspect
import math import math
import os import os
import re import re
from typing import get_type_hints
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
...@@ -314,6 +317,63 @@ class TestDispatchers: ...@@ -314,6 +317,63 @@ class TestDispatchers:
spy.assert_called_once() spy.assert_called_once()
@pytest.mark.parametrize(
("dispatcher_info", "feature_type", "kernel_info"),
[
pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
for dispatcher_info in DISPATCHER_INFOS
for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
],
)
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
kernel_signature = inspect.signature(kernel_info.kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]
# We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
# explicit passed to the kernel.
feature_type_metadata = feature_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]
dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
try:
# In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
# dispatcher parameters that have no kernel equivalent while keeping the order intact.
while dispatcher_param.name != kernel_param.name:
dispatcher_param = next(dispatcher_params)
except StopIteration:
raise AssertionError(
f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
) from None
assert dispatcher_param == kernel_param
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_feature_signatures_consistency(self, info):
try:
feature_method = getattr(features._Feature, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
feature_signature = inspect.signature(feature_method)
feature_params = list(feature_signature.parameters.values())[1:]
# Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined,
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
# concrete dispatcher annotations.
feature_annotations = get_type_hints(feature_method)
for param in feature_params:
param._annotation = feature_annotations[param.name]
assert dispatcher_params == feature_params
@pytest.mark.parametrize( @pytest.mark.parametrize(
("alias", "target"), ("alias", "target"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment