"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5934873b8f4c1dda00a6271bc40fd2a45a1a918e"
Unverified Commit 8ec7a70f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow tolerances in transforms consistency checks (#6774)

parent c960273c
...@@ -12,6 +12,7 @@ import pytest ...@@ -12,6 +12,7 @@ import pytest
import torch import torch
from prototype_common_utils import ( from prototype_common_utils import (
ArgsKwargs, ArgsKwargs,
assert_close,
assert_equal, assert_equal,
make_bounding_box, make_bounding_box,
make_detection_mask, make_detection_mask,
...@@ -40,6 +41,7 @@ class ConsistencyConfig: ...@@ -40,6 +41,7 @@ class ConsistencyConfig:
make_images_kwargs=None, make_images_kwargs=None,
supports_pil=True, supports_pil=True,
removed_params=(), removed_params=(),
closeness_kwargs=None,
): ):
self.prototype_cls = prototype_cls self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls self.legacy_cls = legacy_cls
...@@ -47,6 +49,7 @@ class ConsistencyConfig: ...@@ -47,6 +49,7 @@ class ConsistencyConfig:
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil self.supports_pil = supports_pil
self.removed_params = removed_params self.removed_params = removed_params
self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters # These are here since both the prototype and legacy transform need to be constructed with the same random parameters
...@@ -491,10 +494,14 @@ def test_signature_consistency(config): ...@@ -491,10 +494,14 @@ def test_signature_consistency(config):
assert prototype_kinds == legacy_kinds assert prototype_kinds == legacy_kinds
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True): def check_call_consistency(
prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
if images is None: if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS) images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
closeness_kwargs = closeness_kwargs or dict()
for image in images: for image in images:
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]" image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
...@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s ...@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`is_simple_tensor` path in `_transform`." f"`is_simple_tensor` path in `_transform`."
) from exc ) from exc
assert_equal( assert_close(
output_prototype_tensor, output_prototype_tensor,
output_legacy_tensor, output_legacy_tensor,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}", msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
) )
try: try:
...@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s ...@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`features.Image` path in `_transform`." f"`features.Image` path in `_transform`."
) from exc ) from exc
assert_equal( assert_close(
output_prototype_image, output_prototype_image,
output_prototype_tensor, output_prototype_tensor,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}", msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
) )
if image.ndim == 3 and supports_pil: if image.ndim == 3 and supports_pil:
...@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s ...@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`PIL.Image.Image` path in `_transform`." f"`PIL.Image.Image` path in `_transform`."
) from exc ) from exc
assert_equal( assert_close(
output_prototype_pil, output_prototype_pil,
output_legacy_pil, output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}", msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
) )
...@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs): ...@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
legacy_transform, legacy_transform,
images=make_images(**config.make_images_kwargs), images=make_images(**config.make_images_kwargs),
supports_pil=config.supports_pil, supports_pil=config.supports_pil,
closeness_kwargs=config.closeness_kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment