"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3cff4cc58730409c68f8afa2f3b9c61efa0e85c6"
Unverified commit 209bec46, authored by Sylvain Gugger, committed by GitHub

Add a decorator for flaky tests (#19498)

* Add a decorator for flaky tests

* Quality

* Don't break the rest

* Address review comments

* Fix test name

* Fix typo and print to stderr
parent 1973b771
@@ -14,6 +14,7 @@
 import collections
 import contextlib
+import functools
 import inspect
 import logging
 import os
@@ -23,12 +24,13 @@ import shutil
 import subprocess
 import sys
 import tempfile
+import time
 import unittest
 from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
-from typing import Iterator, List, Union
+from typing import Iterator, List, Optional, Union
 from unittest import mock

 import huggingface_hub
@@ -1635,3 +1637,36 @@ class RequestCounter:
         self.other_request_count += 1
         return self.old_request(method=method, **kwargs)
+
+
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
+    """
+    To decorate flaky tests. They will be retried on failures.
+
+    Args:
+        max_attempts (`int`, *optional*, defaults to 5):
+            The maximum number of attempts to retry the flaky test.
+        wait_before_retry (`float`, *optional*):
+            If provided, will wait that number of seconds before retrying the test.
+    """
+
+    def decorator(test_func_ref):
+        @functools.wraps(test_func_ref)
+        def wrapper(*args, **kwargs):
+            retry_count = 1
+
+            while retry_count < max_attempts:
+                try:
+                    return test_func_ref(*args, **kwargs)
+
+                except Exception as err:
+                    print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+                    if wait_before_retry is not None:
+                        time.sleep(wait_before_retry)
+                    retry_count += 1
+
+            return test_func_ref(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
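For context, a minimal usage sketch of the new decorator (not part of the commit; the toy test and its attempt counter are hypothetical):

```python
# Usage sketch for is_flaky, assuming a transformers install that includes
# this commit. The toy test and the attempts counter are hypothetical.
from transformers.testing_utils import is_flaky

attempts = {"n": 0}  # counts calls so we can simulate an intermittent failure


@is_flaky(max_attempts=3, wait_before_retry=0.1)
def test_sometimes_fails():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise AssertionError("transient failure")  # fails on the first two tries


test_sometimes_fails()  # the first two failures are retried; the third try passes
print(f"passed after {attempts['n']} attempts")
```

The decorator prints each failure to stderr, optionally sleeps between tries, and runs the final attempt outside the try/except, so a persistent failure still propagates and the test is reported as failed.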
@@ -20,7 +20,7 @@ import unittest
 from huggingface_hub import hf_hub_download

 from transformers import is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import is_flaky, require_torch, slow, torch_device

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -359,6 +359,10 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
             [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
         )

+    @is_flaky()
+    def test_retain_grad_hidden_states_attentions(self):
+        super().test_retain_grad_hidden_states_attentions()
+

 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
...
@@ -21,7 +21,9 @@ from datasets import load_dataset

 from transformers import Wav2Vec2Config, is_flax_available
 from transformers.testing_utils import (
+    is_flaky,
     is_librosa_available,
+    is_pt_flax_cross_test,
     is_pyctcdecode_available,
     require_flax,
     require_librosa,
@@ -302,6 +304,11 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         outputs = model(np.ones((1, 1024), dtype="f4"))
         self.assertIsNotNone(outputs)

+    @is_pt_flax_cross_test
+    @is_flaky()
+    def test_equivalence_pt_to_flax(self):
+        super().test_equivalence_pt_to_flax()
+

 @require_flax
 class FlaxWav2Vec2UtilsTest(unittest.TestCase):
...
@@ -26,7 +26,7 @@ from datasets import load_dataset
 from huggingface_hub import snapshot_download

 from transformers import Wav2Vec2Config, is_tf_available
-from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
+from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
 from transformers.utils import is_librosa_available, is_pyctcdecode_available

 from ...test_configuration_common import ConfigTester
@@ -309,6 +309,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_ctc_loss(*config_and_inputs)

+    @is_flaky()
     def test_labels_out_of_vocab(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
...
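One subtlety in the Flax Wav2Vec2 hunk above: Python applies stacked decorators bottom-up, so `is_flaky()` wraps the test method first and `is_pt_flax_cross_test` wraps the retrying version; a test that is gated off is skipped outright and never retried. A minimal sketch of that ordering, with a hypothetical `require_cross_tests` standing in for the real gate:

```python
# Sketch of the decorator stacking used in test_equivalence_pt_to_flax above.
# require_cross_tests is a hypothetical stand-in for is_pt_flax_cross_test.
import unittest

from transformers.testing_utils import is_flaky

RUN_CROSS_TESTS = False  # assumption: mirrors an env-driven opt-in flag


def require_cross_tests(test_func):
    # Outer decorator: when the flag is off, unittest skips the test before
    # the inner retry wrapper ever runs.
    return unittest.skipUnless(RUN_CROSS_TESTS, "PT/Flax cross tests disabled")(test_func)


class ToyEquivalenceTest(unittest.TestCase):
    @require_cross_tests       # applied last (outermost): the skip gate
    @is_flaky(max_attempts=2)  # applied first (innermost): the retry logic
    def test_equivalence(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()  # reports the test as skipped; no retry attempts run
```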