"tests/longformer/test_modeling_longformer.py" did not exist on "e78c1103385f2d2f9cd4980f61a8e71baa655356"
Unverified Commit 57e6464a authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Update all require decorators to use skipUnless when possible (#16999)

parent e952e049
......@@ -203,10 +203,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
if not _run_slow_tests:
return unittest.skip("test is slow")(test_case)
else:
return test_case
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def tooslow(test_case):
......@@ -227,10 +224,7 @@ def custom_tokenizers(test_case):
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
environment variable to a truthy value to run them.
"""
if not _run_custom_tokenizers:
return unittest.skip("test of custom tokenizers")(test_case)
else:
return test_case
return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
def require_git_lfs(test_case):
......@@ -240,34 +234,22 @@ def require_git_lfs(test_case):
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
variable to a truthy value to run them.
"""
if not _run_git_lfs_tests:
return unittest.skip("test of git lfs workflow")(test_case)
else:
return test_case
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
def require_rjieba(test_case):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
"""
if not is_rjieba_available():
return unittest.skip("test requires rjieba")(test_case)
else:
return test_case
return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
def require_tf2onnx(test_case):
if not is_tf2onnx_available():
return unittest.skip("test requires tf2onnx")(test_case)
else:
return test_case
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
def require_onnx(test_case):
if not is_onnx_available():
return unittest.skip("test requires ONNX")(test_case)
else:
return test_case
return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
def require_timm(test_case):
......@@ -277,10 +259,7 @@ def require_timm(test_case):
These tests are skipped when Timm isn't installed.
"""
if not is_timm_available():
return unittest.skip("test requires Timm")(test_case)
else:
return test_case
return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
def require_torch(test_case):
......@@ -290,10 +269,7 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_torch_scatter(test_case):
......@@ -303,10 +279,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed.
"""
if not is_scatter_available():
return unittest.skip("test requires PyTorch scatter")(test_case)
else:
return test_case
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)
def require_tensorflow_probability(test_case):
......@@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case):
These tests are skipped when TensorFlow probability isn't installed.
"""
if not is_tensorflow_probability_available():
return unittest.skip("test requires TensorFlow probability")(test_case)
else:
return test_case
return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
test_case
)
def require_torchaudio(test_case):
"""
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
"""
if not is_torchaudio_available():
return unittest.skip("test requires torchaudio")(test_case)
else:
return test_case
return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
def require_tf(test_case):
"""
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
"""
if not is_tf_available():
return unittest.skip("test requires TensorFlow")(test_case)
else:
return test_case
return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
if not is_flax_available():
test_case = unittest.skip("test requires JAX & Flax")(test_case)
return test_case
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
def require_sentencepiece(test_case):
"""
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
"""
if not is_sentencepiece_available():
return unittest.skip("test requires SentencePiece")(test_case)
else:
return test_case
return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
def require_scipy(test_case):
"""
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
"""
if not is_scipy_available():
return unittest.skip("test requires Scipy")(test_case)
else:
return test_case
return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
def require_tokenizers(test_case):
"""
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
"""
if not is_tokenizers_available():
return unittest.skip("test requires tokenizers")(test_case)
else:
return test_case
return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
if not is_pandas_available():
return unittest.skip("test requires pandas")(test_case)
else:
return test_case
return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
def require_pytesseract(test_case):
"""
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
"""
if not is_pytesseract_available():
return unittest.skip("test requires PyTesseract")(test_case)
else:
return test_case
return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
def require_scatter(test_case):
......@@ -406,10 +355,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
if not is_scatter_available():
return unittest.skip("test requires PyTorch Scatter")(test_case)
else:
return test_case
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)
def require_pytorch_quantization(test_case):
......@@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case):
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
Quantization Toolkit isn't installed.
"""
if not is_pytorch_quantization_available():
return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case)
else:
return test_case
return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
test_case
)
def require_vision(test_case):
......@@ -428,30 +373,21 @@ def require_vision(test_case):
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
installed.
"""
if not is_vision_available():
return unittest.skip("test requires vision")(test_case)
else:
return test_case
return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
def require_ftfy(test_case):
"""
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
"""
if not is_ftfy_available():
return unittest.skip("test requires ftfy")(test_case)
else:
return test_case
return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
def require_spacy(test_case):
"""
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
"""
if not is_spacy_available():
return unittest.skip("test requires spacy")(test_case)
else:
return test_case
return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
def require_torch_multi_gpu(test_case):
......@@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case):
import torch
if torch.cuda.device_count() < 2:
return unittest.skip("test requires multiple GPUs")(test_case)
else:
return test_case
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
def require_torch_non_multi_gpu(test_case):
......@@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case):
import torch
if torch.cuda.device_count() > 1:
return unittest.skip("test requires 0 or 1 GPU")(test_case)
else:
return test_case
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
def require_torch_up_to_2_gpus(test_case):
......@@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case):
import torch
if torch.cuda.device_count() > 2:
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
else:
return test_case
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if not is_torch_tpu_available():
return unittest.skip("test requires PyTorch TPU")
else:
return test_case
return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
if is_torch_available():
......@@ -533,42 +457,31 @@ else:
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
if torch_device != "cuda":
return unittest.skip("test requires CUDA")(test_case)
else:
return test_case
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_bf16(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
if not is_torch_bf16_available():
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
else:
return test_case
return unittest.skipUnless(
is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
)(test_case)
def require_torch_tf32(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
if not is_torch_tf32_available():
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
else:
return test_case
return unittest.skipUnless(
is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
)(test_case)
def require_detectron2(test_case):
"""Decorator marking a test that requires detectron2."""
if not is_detectron2_available():
return unittest.skip("test requires `detectron2`")(test_case)
else:
return test_case
return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not is_faiss_available():
return unittest.skip("test requires `faiss`")(test_case)
else:
return test_case
return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
def require_optuna(test_case):
......@@ -578,10 +491,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed.
"""
if not is_optuna_available():
return unittest.skip("test requires optuna")(test_case)
else:
return test_case
return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
def require_ray(test_case):
......@@ -591,10 +501,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed.
"""
if not is_ray_available():
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
def require_sigopt(test_case):
......@@ -604,10 +511,7 @@ def require_sigopt(test_case):
These tests are skipped when SigOpt isn't installed.
"""
if not is_sigopt_available():
return unittest.skip("test requires SigOpt")(test_case)
else:
return test_case
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
def require_wandb(test_case):
......@@ -617,10 +521,7 @@ def require_wandb(test_case):
These tests are skipped when wandb isn't installed.
"""
if not is_wandb_available():
return unittest.skip("test requires wandb")(test_case)
else:
return test_case
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
def require_soundfile(test_case):
......@@ -630,80 +531,56 @@ def require_soundfile(test_case):
These tests are skipped when soundfile isn't installed.
"""
if not is_soundfile_availble():
return unittest.skip("test requires soundfile")(test_case)
else:
return test_case
return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case
return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case
return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
def require_bitsandbytes(test_case):
"""
Decorator for bits and bytes (bnb) dependency
"""
if not is_bitsandbytes_available():
return unittest.skip("test requires bnb")(test_case)
else:
return test_case
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
"""
if not is_phonemizer_available():
return unittest.skip("test requires phonemizer")(test_case)
else:
return test_case
return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
def require_pyctcdecode(test_case):
"""
Decorator marking a test that requires pyctcdecode
"""
if not is_pyctcdecode_available():
return unittest.skip("test requires pyctcdecode")(test_case)
else:
return test_case
return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
def require_librosa(test_case):
"""
Decorator marking a test that requires librosa
"""
if not is_librosa_available():
return unittest.skip("test requires librosa")(test_case)
else:
return test_case
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
def cmd_exists(cmd):
......@@ -714,10 +591,7 @@ def require_usr_bin_time(test_case):
"""
Decorator marking a test that requires `/usr/bin/time`
"""
if not cmd_exists("/usr/bin/time"):
return unittest.skip("test requires /usr/bin/time")(test_case)
else:
return test_case
return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
def get_gpu_count():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment