Unverified Commit 22b19d57 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Tests] Add is flaky decorator (#5139)

* add is flaky decorator

* fix more
parent 787195fe
import functools
import importlib import importlib
import inspect import inspect
import io import io
...@@ -7,7 +8,9 @@ import os ...@@ -7,7 +8,9 @@ import os
import random import random
import re import re
import struct import struct
import sys
import tempfile import tempfile
import time
import unittest import unittest
import urllib.parse import urllib.parse
from contextlib import contextmanager from contextlib import contextmanager
...@@ -612,6 +615,43 @@ def pytest_terminal_summary_main(tr, id): ...@@ -612,6 +615,43 @@ def pytest_terminal_summary_main(tr, id):
config.option.tbstyle = orig_tbstyle config.option.tbstyle = orig_tbstyle
# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
    """
    Decorator that reruns a flaky test until it passes or attempts run out.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            Total number of times the test is run before the last failure propagates.
        wait_before_retry (`float`, *optional*):
            Seconds to sleep between consecutive attempts; no wait when `None`.
        description (`str`, *optional*):
            Free-form note about why/where the test is flaky (GH issue link, error
            signatures, etc.). Kept for documentation purposes only.
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            # Attempts 1..max_attempts-1: swallow the failure, log it, optionally wait.
            for attempt in range(1, max_attempts):
                try:
                    return test_func_ref(*args, **kwargs)
                except Exception as err:
                    print(f"Test failed with {err} at try {attempt}/{max_attempts}.", file=sys.stderr)
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)
            # Final attempt: any exception now propagates to the test runner.
            return test_func_ref(*args, **kwargs)

        return wrapper

    return decorator
# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 # Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
""" """
......
...@@ -30,6 +30,7 @@ from diffusers.utils import is_xformers_available ...@@ -30,6 +30,7 @@ from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
floats_tensor, floats_tensor,
is_flaky,
skip_mps, skip_mps,
slow, slow,
torch_device, torch_device,
...@@ -156,9 +157,14 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -156,9 +157,14 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@is_flaky()
def test_save_load_optional_components(self):
    # NOTE(review): marked flaky in PR #5139 — occasionally exceeds tolerance;
    # is_flaky retries up to 5 times before failing for real.
    super().test_save_load_optional_components(expected_max_difference=0.001)
@is_flaky()
def test_dict_tuple_outputs_equivalent(self):
    # NOTE(review): marked flaky in PR #5139 — retried via is_flaky; delegates
    # to the mixin implementation with its default tolerances.
    super().test_dict_tuple_outputs_equivalent()
@unittest.skipIf( @unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(), torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed", reason="XFormers attention is only available with CUDA and `xformers` installed",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment