"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3f6add8bab8245f819a8950b7c5d09aff9366e32"
Unverified Commit 57e6464a authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Update all require decorators to use skipUnless when possible (#16999)

parent e952e049
...@@ -203,10 +203,7 @@ def slow(test_case): ...@@ -203,10 +203,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
""" """
if not _run_slow_tests: return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
return unittest.skip("test is slow")(test_case)
else:
return test_case
def tooslow(test_case): def tooslow(test_case):
...@@ -227,10 +224,7 @@ def custom_tokenizers(test_case): ...@@ -227,10 +224,7 @@ def custom_tokenizers(test_case):
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
environment variable to a truthy value to run them. environment variable to a truthy value to run them.
""" """
if not _run_custom_tokenizers: return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
return unittest.skip("test of custom tokenizers")(test_case)
else:
return test_case
def require_git_lfs(test_case): def require_git_lfs(test_case):
...@@ -240,34 +234,22 @@ def require_git_lfs(test_case): ...@@ -240,34 +234,22 @@ def require_git_lfs(test_case):
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
variable to a truthy value to run them. variable to a truthy value to run them.
""" """
if not _run_git_lfs_tests: return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
return unittest.skip("test of git lfs workflow")(test_case)
else:
return test_case
def require_rjieba(test_case): def require_rjieba(test_case):
""" """
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
""" """
if not is_rjieba_available(): return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
return unittest.skip("test requires rjieba")(test_case)
else:
return test_case
def require_tf2onnx(test_case): def require_tf2onnx(test_case):
if not is_tf2onnx_available(): return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
return unittest.skip("test requires tf2onnx")(test_case)
else:
return test_case
def require_onnx(test_case): def require_onnx(test_case):
if not is_onnx_available(): return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
return unittest.skip("test requires ONNX")(test_case)
else:
return test_case
def require_timm(test_case): def require_timm(test_case):
...@@ -277,10 +259,7 @@ def require_timm(test_case): ...@@ -277,10 +259,7 @@ def require_timm(test_case):
These tests are skipped when Timm isn't installed. These tests are skipped when Timm isn't installed.
""" """
if not is_timm_available(): return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
return unittest.skip("test requires Timm")(test_case)
else:
return test_case
def require_torch(test_case): def require_torch(test_case):
...@@ -290,10 +269,7 @@ def require_torch(test_case): ...@@ -290,10 +269,7 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed. These tests are skipped when PyTorch isn't installed.
""" """
if not is_torch_available(): return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case
def require_torch_scatter(test_case): def require_torch_scatter(test_case):
...@@ -303,10 +279,7 @@ def require_torch_scatter(test_case): ...@@ -303,10 +279,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed. These tests are skipped when PyTorch scatter isn't installed.
""" """
if not is_scatter_available(): return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)
return unittest.skip("test requires PyTorch scatter")(test_case)
else:
return test_case
def require_tensorflow_probability(test_case): def require_tensorflow_probability(test_case):
...@@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case): ...@@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case):
These tests are skipped when TensorFlow probability isn't installed. These tests are skipped when TensorFlow probability isn't installed.
""" """
if not is_tensorflow_probability_available(): return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
return unittest.skip("test requires TensorFlow probability")(test_case) test_case
else: )
return test_case
def require_torchaudio(test_case): def require_torchaudio(test_case):
""" """
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed. Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
""" """
if not is_torchaudio_available(): return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
return unittest.skip("test requires torchaudio")(test_case)
else:
return test_case
def require_tf(test_case): def require_tf(test_case):
""" """
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed. Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
""" """
if not is_tf_available(): return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
return unittest.skip("test requires TensorFlow")(test_case)
else:
return test_case
def require_flax(test_case): def require_flax(test_case):
""" """
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
""" """
if not is_flax_available(): return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
test_case = unittest.skip("test requires JAX & Flax")(test_case)
return test_case
def require_sentencepiece(test_case): def require_sentencepiece(test_case):
""" """
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
""" """
if not is_sentencepiece_available(): return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
return unittest.skip("test requires SentencePiece")(test_case)
else:
return test_case
def require_scipy(test_case): def require_scipy(test_case):
""" """
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
""" """
if not is_scipy_available(): return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
return unittest.skip("test requires Scipy")(test_case)
else:
return test_case
def require_tokenizers(test_case): def require_tokenizers(test_case):
""" """
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
""" """
if not is_tokenizers_available(): return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
return unittest.skip("test requires tokenizers")(test_case)
else:
return test_case
def require_pandas(test_case): def require_pandas(test_case):
""" """
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
""" """
if not is_pandas_available(): return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
return unittest.skip("test requires pandas")(test_case)
else:
return test_case
def require_pytesseract(test_case): def require_pytesseract(test_case):
""" """
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed. Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
""" """
if not is_pytesseract_available(): return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
return unittest.skip("test requires PyTesseract")(test_case)
else:
return test_case
def require_scatter(test_case): def require_scatter(test_case):
...@@ -406,10 +355,7 @@ def require_scatter(test_case): ...@@ -406,10 +355,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed. installed.
""" """
if not is_scatter_available(): return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)
return unittest.skip("test requires PyTorch Scatter")(test_case)
else:
return test_case
def require_pytorch_quantization(test_case): def require_pytorch_quantization(test_case):
...@@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case): ...@@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case):
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
Quantization Toolkit isn't installed. Quantization Toolkit isn't installed.
""" """
if not is_pytorch_quantization_available(): return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case) test_case
else: )
return test_case
def require_vision(test_case): def require_vision(test_case):
...@@ -428,30 +373,21 @@ def require_vision(test_case): ...@@ -428,30 +373,21 @@ def require_vision(test_case):
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
installed. installed.
""" """
if not is_vision_available(): return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
return unittest.skip("test requires vision")(test_case)
else:
return test_case
def require_ftfy(test_case): def require_ftfy(test_case):
""" """
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
""" """
if not is_ftfy_available(): return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
return unittest.skip("test requires ftfy")(test_case)
else:
return test_case
def require_spacy(test_case): def require_spacy(test_case):
""" """
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed. Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
""" """
if not is_spacy_available(): return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
return unittest.skip("test requires spacy")(test_case)
else:
return test_case
def require_torch_multi_gpu(test_case): def require_torch_multi_gpu(test_case):
...@@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case): ...@@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case):
import torch import torch
if torch.cuda.device_count() < 2: return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
return unittest.skip("test requires multiple GPUs")(test_case)
else:
return test_case
def require_torch_non_multi_gpu(test_case): def require_torch_non_multi_gpu(test_case):
...@@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case): ...@@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case):
import torch import torch
if torch.cuda.device_count() > 1: return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
return unittest.skip("test requires 0 or 1 GPU")(test_case)
else:
return test_case
def require_torch_up_to_2_gpus(test_case): def require_torch_up_to_2_gpus(test_case):
...@@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case): ...@@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case):
import torch import torch
if torch.cuda.device_count() > 2: return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
else:
return test_case
def require_torch_tpu(test_case): def require_torch_tpu(test_case):
""" """
Decorator marking a test that requires a TPU (in PyTorch). Decorator marking a test that requires a TPU (in PyTorch).
""" """
if not is_torch_tpu_available(): return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
return unittest.skip("test requires PyTorch TPU")
else:
return test_case
if is_torch_available(): if is_torch_available():
...@@ -533,42 +457,31 @@ else: ...@@ -533,42 +457,31 @@ else:
def require_torch_gpu(test_case): def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch.""" """Decorator marking a test that requires CUDA and PyTorch."""
if torch_device != "cuda": return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
return unittest.skip("test requires CUDA")(test_case)
else:
return test_case
def require_torch_bf16(test_case): def require_torch_bf16(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10.""" """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
if not is_torch_bf16_available(): return unittest.skipUnless(
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case) is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
else: )(test_case)
return test_case
def require_torch_tf32(test_case): def require_torch_tf32(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.""" """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
if not is_torch_tf32_available(): return unittest.skipUnless(
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case) is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
else: )(test_case)
return test_case
def require_detectron2(test_case): def require_detectron2(test_case):
"""Decorator marking a test that requires detectron2.""" """Decorator marking a test that requires detectron2."""
if not is_detectron2_available(): return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
return unittest.skip("test requires `detectron2`")(test_case)
else:
return test_case
def require_faiss(test_case): def require_faiss(test_case):
"""Decorator marking a test that requires faiss.""" """Decorator marking a test that requires faiss."""
if not is_faiss_available(): return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
return unittest.skip("test requires `faiss`")(test_case)
else:
return test_case
def require_optuna(test_case): def require_optuna(test_case):
...@@ -578,10 +491,7 @@ def require_optuna(test_case): ...@@ -578,10 +491,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed. These tests are skipped when optuna isn't installed.
""" """
if not is_optuna_available(): return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
return unittest.skip("test requires optuna")(test_case)
else:
return test_case
def require_ray(test_case): def require_ray(test_case):
...@@ -591,10 +501,7 @@ def require_ray(test_case): ...@@ -591,10 +501,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed. These tests are skipped when Ray/tune isn't installed.
""" """
if not is_ray_available(): return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
def require_sigopt(test_case): def require_sigopt(test_case):
...@@ -604,10 +511,7 @@ def require_sigopt(test_case): ...@@ -604,10 +511,7 @@ def require_sigopt(test_case):
These tests are skipped when SigOpt isn't installed. These tests are skipped when SigOpt isn't installed.
""" """
if not is_sigopt_available(): return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
return unittest.skip("test requires SigOpt")(test_case)
else:
return test_case
def require_wandb(test_case): def require_wandb(test_case):
...@@ -617,10 +521,7 @@ def require_wandb(test_case): ...@@ -617,10 +521,7 @@ def require_wandb(test_case):
These tests are skipped when wandb isn't installed. These tests are skipped when wandb isn't installed.
""" """
if not is_wandb_available(): return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
return unittest.skip("test requires wandb")(test_case)
else:
return test_case
def require_soundfile(test_case): def require_soundfile(test_case):
...@@ -630,80 +531,56 @@ def require_soundfile(test_case): ...@@ -630,80 +531,56 @@ def require_soundfile(test_case):
These tests are skipped when soundfile isn't installed. These tests are skipped when soundfile isn't installed.
""" """
if not is_soundfile_availble(): return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)
return unittest.skip("test requires soundfile")(test_case)
else:
return test_case
def require_deepspeed(test_case): def require_deepspeed(test_case):
""" """
Decorator marking a test that requires deepspeed Decorator marking a test that requires deepspeed
""" """
if not is_deepspeed_available(): return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
def require_fairscale(test_case): def require_fairscale(test_case):
""" """
Decorator marking a test that requires fairscale Decorator marking a test that requires fairscale
""" """
if not is_fairscale_available(): return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case
def require_apex(test_case): def require_apex(test_case):
""" """
Decorator marking a test that requires apex Decorator marking a test that requires apex
""" """
if not is_apex_available(): return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
return unittest.skip("test requires apex")(test_case)
else:
return test_case
def require_bitsandbytes(test_case): def require_bitsandbytes(test_case):
""" """
Decorator for bits and bytes (bnb) dependency Decorator for bits and bytes (bnb) dependency
""" """
if not is_bitsandbytes_available(): return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)
return unittest.skip("test requires bnb")(test_case)
else:
return test_case
def require_phonemizer(test_case): def require_phonemizer(test_case):
""" """
Decorator marking a test that requires phonemizer Decorator marking a test that requires phonemizer
""" """
if not is_phonemizer_available(): return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
return unittest.skip("test requires phonemizer")(test_case)
else:
return test_case
def require_pyctcdecode(test_case): def require_pyctcdecode(test_case):
""" """
Decorator marking a test that requires pyctcdecode Decorator marking a test that requires pyctcdecode
""" """
if not is_pyctcdecode_available(): return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
return unittest.skip("test requires pyctcdecode")(test_case)
else:
return test_case
def require_librosa(test_case): def require_librosa(test_case):
""" """
Decorator marking a test that requires librosa Decorator marking a test that requires librosa
""" """
if not is_librosa_available(): return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
return unittest.skip("test requires librosa")(test_case)
else:
return test_case
def cmd_exists(cmd): def cmd_exists(cmd):
...@@ -714,10 +591,7 @@ def require_usr_bin_time(test_case): ...@@ -714,10 +591,7 @@ def require_usr_bin_time(test_case):
""" """
Decorator marking a test that requires `/usr/bin/time` Decorator marking a test that requires `/usr/bin/time`
""" """
if not cmd_exists("/usr/bin/time"): return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
return unittest.skip("test requires /usr/bin/time")(test_case)
else:
return test_case
def get_gpu_count(): def get_gpu_count():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment