Unverified Commit cc1f9a2c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] mark the wanvace lora tester flaky (#11883)

* mark wanvace lora tests as flaky

* ability to apply is_flaky at a class-level

* update

* increase max_attempt.

* increase attemtp.
parent 737d7fc3
......@@ -994,10 +994,10 @@ def pytest_terminal_summary_main(tr, id):
config.option.tbstyle = orig_tbstyle
# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
To decorate flaky tests. They will be retried on failures.
To decorate flaky tests (methods or entire classes). They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
......@@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
etc.)
"""
def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def decorator(obj):
# If decorating a class, wrap each test method on it
if inspect.isclass(obj):
for attr_name, attr_value in list(obj.__dict__.items()):
if callable(attr_value) and attr_name.startswith("test"):
# recursively decorate the method
setattr(obj, attr_name, decorator(attr_value))
return obj
# Otherwise we're decorating a single test function / method
@functools.wraps(obj)
def wrapper(*args, **kwargs):
retry_count = 1
while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
return obj(*args, **kwargs)
except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
msg = (
f"[FLAKY] {description or obj.__name__!r} "
f"failed on attempt {retry_count}/{max_attempts}: {err}"
)
print(msg, file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
return test_func_ref(*args, **kwargs)
return obj(*args, **kwargs)
return wrapper
......
......@@ -46,6 +46,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
......@@ -217,6 +218,5 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Lora outputs should match.",
)
@is_flaky
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment