Unverified Commit 209bec46 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add a decorator for flaky tests (#19498)

* Add a decorator for flaky tests

* Quality

* Don't break the rest

* Address review comments

* Fix test name

* Fix typo and print to stderr
parent 1973b771
......@@ -14,6 +14,7 @@
import collections
import contextlib
import functools
import inspect
import logging
import os
......@@ -23,12 +24,13 @@ import shutil
import subprocess
import sys
import tempfile
import time
import unittest
from collections.abc import Mapping
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Iterator, List, Union
from typing import Iterator, List, Optional, Union
from unittest import mock
import huggingface_hub
......@@ -1635,3 +1637,36 @@ class RequestCounter:
self.other_request_count += 1
return self.old_request(method=method, **kwargs)
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
"""
To decorate flaky tests. They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*):
If provided, will wait that number of seconds before retrying the test.
"""
def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1
while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
return test_func_ref(*args, **kwargs)
return wrapper
return decorator
......@@ -20,7 +20,7 @@ import unittest
from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
......@@ -359,6 +359,10 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
)
@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
......
......@@ -21,7 +21,9 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
is_pyctcdecode_available,
require_flax,
require_librosa,
......@@ -302,6 +304,11 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@is_pt_flax_cross_test
@is_flaky()
def test_equivalence_pt_to_flax(self):
super().test_equivalence_pt_to_flax()
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
......
......@@ -26,7 +26,7 @@ from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available
from ...test_configuration_common import ConfigTester
......@@ -309,6 +309,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
@is_flaky()
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment