Unverified Commit 8ec7a70f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow tolerances in transforms consistency checks (#6774)

parent c960273c
......@@ -12,6 +12,7 @@ import pytest
import torch
from prototype_common_utils import (
ArgsKwargs,
assert_close,
assert_equal,
make_bounding_box,
make_detection_mask,
......@@ -40,6 +41,7 @@ class ConsistencyConfig:
make_images_kwargs=None,
supports_pil=True,
removed_params=(),
closeness_kwargs=None,
):
self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls
......@@ -47,6 +49,7 @@ class ConsistencyConfig:
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil
self.removed_params = removed_params
self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
......@@ -491,10 +494,14 @@ def test_signature_consistency(config):
assert prototype_kinds == legacy_kinds
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
def check_call_consistency(
prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
closeness_kwargs = closeness_kwargs or dict()
for image in images:
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
......@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`is_simple_tensor` path in `_transform`."
) from exc
assert_equal(
assert_close(
output_prototype_tensor,
output_legacy_tensor,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
)
try:
......@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`features.Image` path in `_transform`."
) from exc
assert_equal(
assert_close(
output_prototype_image,
output_prototype_tensor,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
)
if image.ndim == 3 and supports_pil:
......@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`PIL.Image.Image` path in `_transform`."
) from exc
assert_equal(
assert_close(
output_prototype_pil,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
)
......@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
legacy_transform,
images=make_images(**config.make_images_kwargs),
supports_pil=config.supports_pil,
closeness_kwargs=config.closeness_kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment