Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
57e6464a
"tests/longformer/test_modeling_longformer.py" did not exist on "e78c1103385f2d2f9cd4980f61a8e71baa655356"
Unverified
Commit
57e6464a
authored
Apr 29, 2022
by
Zachary Mueller
Committed by
GitHub
Apr 29, 2022
Browse files
Update all require decorators to use skipUnless when possible (#16999)
parent
e952e049
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
179 deletions
+53
-179
src/transformers/testing_utils.py
src/transformers/testing_utils.py
+53
-179
No files found.
src/transformers/testing_utils.py
View file @
57e6464a
...
...
@@ -203,10 +203,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
if
not
_run_slow_tests
:
return
unittest
.
skip
(
"test is slow"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
_run_slow_tests
,
"test is slow"
)(
test_case
)
def
tooslow
(
test_case
):
...
...
@@ -227,10 +224,7 @@ def custom_tokenizers(test_case):
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
environment variable to a truthy value to run them.
"""
if
not
_run_custom_tokenizers
:
return
unittest
.
skip
(
"test of custom tokenizers"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
_run_custom_tokenizers
,
"test of custom tokenizers"
)(
test_case
)
def
require_git_lfs
(
test_case
):
...
...
@@ -240,34 +234,22 @@ def require_git_lfs(test_case):
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
variable to a truthy value to run them.
"""
if
not
_run_git_lfs_tests
:
return
unittest
.
skip
(
"test of git lfs workflow"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
_run_git_lfs_tests
,
"test of git lfs workflow"
)(
test_case
)
def
require_rjieba
(
test_case
):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
"""
if
not
is_rjieba_available
():
return
unittest
.
skip
(
"test requires rjieba"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_rjieba_available
(),
"test requires rjieba"
)(
test_case
)
def
require_tf2onnx
(
test_case
):
if
not
is_tf2onnx_available
():
return
unittest
.
skip
(
"test requires tf2onnx"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_tf2onnx_available
(),
"test requires tf2onnx"
)(
test_case
)
def
require_onnx
(
test_case
):
if
not
is_onnx_available
():
return
unittest
.
skip
(
"test requires ONNX"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_onnx_available
(),
"test requires ONNX"
)(
test_case
)
def
require_timm
(
test_case
):
...
...
@@ -277,10 +259,7 @@ def require_timm(test_case):
These tests are skipped when Timm isn't installed.
"""
if
not
is_timm_available
():
return
unittest
.
skip
(
"test requires Timm"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_timm_available
(),
"test requires Timm"
)(
test_case
)
def
require_torch
(
test_case
):
...
...
@@ -290,10 +269,7 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed.
"""
if
not
is_torch_available
():
return
unittest
.
skip
(
"test requires PyTorch"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_torch_available
(),
"test requires PyTorch"
)(
test_case
)
def
require_torch_scatter
(
test_case
):
...
...
@@ -303,10 +279,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed.
"""
if
not
is_scatter_available
():
return
unittest
.
skip
(
"test requires PyTorch scatter"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_scatter_available
(),
"test requires PyTorch scatter"
)(
test_case
)
def
require_tensorflow_probability
(
test_case
):
...
...
@@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case):
These tests are skipped when TensorFlow probability isn't installed.
"""
if
not
is_tensorflow_probability_available
():
return
unittest
.
skip
(
"test requires TensorFlow probability"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_tensorflow_probability_available
(),
"test requires TensorFlow probability"
)(
test_case
)
def
require_torchaudio
(
test_case
):
"""
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
"""
if
not
is_torchaudio_available
():
return
unittest
.
skip
(
"test requires torchaudio"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_torchaudio_available
(),
"test requires torchaudio"
)(
test_case
)
def
require_tf
(
test_case
):
"""
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
"""
if
not
is_tf_available
():
return
unittest
.
skip
(
"test requires TensorFlow"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_tf_available
(),
"test requires TensorFlow"
)(
test_case
)
def
require_flax
(
test_case
):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
if
not
is_flax_available
():
test_case
=
unittest
.
skip
(
"test requires JAX & Flax"
)(
test_case
)
return
test_case
return
unittest
.
skipUnless
(
is_flax_available
(),
"test requires JAX & Flax"
)(
test_case
)
def
require_sentencepiece
(
test_case
):
"""
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
"""
if
not
is_sentencepiece_available
():
return
unittest
.
skip
(
"test requires SentencePiece"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_sentencepiece_available
(),
"test requires SentencePiece"
)(
test_case
)
def
require_scipy
(
test_case
):
"""
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
"""
if
not
is_scipy_available
():
return
unittest
.
skip
(
"test requires Scipy"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_scipy_available
(),
"test requires Scipy"
)(
test_case
)
def
require_tokenizers
(
test_case
):
"""
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
"""
if
not
is_tokenizers_available
():
return
unittest
.
skip
(
"test requires tokenizers"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_tokenizers_available
(),
"test requires tokenizers"
)(
test_case
)
def
require_pandas
(
test_case
):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
if
not
is_pandas_available
():
return
unittest
.
skip
(
"test requires pandas"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_pandas_available
(),
"test requires pandas"
)(
test_case
)
def
require_pytesseract
(
test_case
):
"""
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
"""
if
not
is_pytesseract_available
():
return
unittest
.
skip
(
"test requires PyTesseract"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_pytesseract_available
(),
"test requires PyTesseract"
)(
test_case
)
def
require_scatter
(
test_case
):
...
...
@@ -406,10 +355,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
if
not
is_scatter_available
():
return
unittest
.
skip
(
"test requires PyTorch Scatter"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_scatter_available
(),
"test requires PyTorch Scatter"
)(
test_case
)
def
require_pytorch_quantization
(
test_case
):
...
...
@@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case):
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
Quantization Toolkit isn't installed.
"""
if
not
is_pytorch_quantization_available
():
return
unittest
.
skip
(
"test requires PyTorch Quantization Toolkit"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_pytorch_quantization_available
(),
"test requires PyTorch Quantization Toolkit"
)(
test_case
)
def
require_vision
(
test_case
):
...
...
@@ -428,30 +373,21 @@ def require_vision(test_case):
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
installed.
"""
if
not
is_vision_available
():
return
unittest
.
skip
(
"test requires vision"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_vision_available
(),
"test requires vision"
)(
test_case
)
def
require_ftfy
(
test_case
):
"""
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
"""
if
not
is_ftfy_available
():
return
unittest
.
skip
(
"test requires ftfy"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_ftfy_available
(),
"test requires ftfy"
)(
test_case
)
def
require_spacy
(
test_case
):
"""
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
"""
if
not
is_spacy_available
():
return
unittest
.
skip
(
"test requires spacy"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_spacy_available
(),
"test requires spacy"
)(
test_case
)
def
require_torch_multi_gpu
(
test_case
):
...
...
@@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case):
import
torch
if
torch
.
cuda
.
device_count
()
<
2
:
return
unittest
.
skip
(
"test requires multiple GPUs"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
torch
.
cuda
.
device_count
()
>
1
,
"test requires multiple GPUs"
)(
test_case
)
def
require_torch_non_multi_gpu
(
test_case
):
...
...
@@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case):
import
torch
if
torch
.
cuda
.
device_count
()
>
1
:
return
unittest
.
skip
(
"test requires 0 or 1 GPU"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
torch
.
cuda
.
device_count
()
<
2
,
"test requires 0 or 1 GPU"
)(
test_case
)
def
require_torch_up_to_2_gpus
(
test_case
):
...
...
@@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case):
import
torch
if
torch
.
cuda
.
device_count
()
>
2
:
return
unittest
.
skip
(
"test requires 0 or 1 or 2 GPUs"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
torch
.
cuda
.
device_count
()
<
3
,
"test requires 0 or 1 or 2 GPUs"
)(
test_case
)
def
require_torch_tpu
(
test_case
):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if
not
is_torch_tpu_available
():
return
unittest
.
skip
(
"test requires PyTorch TPU"
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_torch_tpu_available
(),
"test requires PyTorch TPU"
)(
test_case
)
if
is_torch_available
():
...
...
@@ -533,42 +457,31 @@ else:
def
require_torch_gpu
(
test_case
):
"""Decorator marking a test that requires CUDA and PyTorch."""
if
torch_device
!=
"cuda"
:
return
unittest
.
skip
(
"test requires CUDA"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
torch_device
==
"cuda"
,
"test requires CUDA"
)(
test_case
)
def
require_torch_bf16
(
test_case
):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
if
not
is_torch_bf16_available
():
return
unittest
.
skip
(
"test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_torch_bf16_available
(),
"test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
)(
test_case
)
def
require_torch_tf32
(
test_case
):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
if
not
is_torch_tf32_available
():
return
unittest
.
skip
(
"test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_torch_tf32_available
(),
"test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
)(
test_case
)
def
require_detectron2
(
test_case
):
"""Decorator marking a test that requires detectron2."""
if
not
is_detectron2_available
():
return
unittest
.
skip
(
"test requires `detectron2`"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_detectron2_available
(),
"test requires `detectron2`"
)(
test_case
)
def
require_faiss
(
test_case
):
"""Decorator marking a test that requires faiss."""
if
not
is_faiss_available
():
return
unittest
.
skip
(
"test requires `faiss`"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_faiss_available
(),
"test requires `faiss`"
)(
test_case
)
def
require_optuna
(
test_case
):
...
...
@@ -578,10 +491,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed.
"""
if
not
is_optuna_available
():
return
unittest
.
skip
(
"test requires optuna"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_optuna_available
(),
"test requires optuna"
)(
test_case
)
def
require_ray
(
test_case
):
...
...
@@ -591,10 +501,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed.
"""
if
not
is_ray_available
():
return
unittest
.
skip
(
"test requires Ray/tune"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_ray_available
(),
"test requires Ray/tune"
)(
test_case
)
def
require_sigopt
(
test_case
):
...
...
@@ -604,10 +511,7 @@ def require_sigopt(test_case):
These tests are skipped when SigOpt isn't installed.
"""
if
not
is_sigopt_available
():
return
unittest
.
skip
(
"test requires SigOpt"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_sigopt_available
(),
"test requires SigOpt"
)(
test_case
)
def
require_wandb
(
test_case
):
...
...
@@ -617,10 +521,7 @@ def require_wandb(test_case):
These tests are skipped when wandb isn't installed.
"""
if
not
is_wandb_available
():
return
unittest
.
skip
(
"test requires wandb"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_wandb_available
(),
"test requires wandb"
)(
test_case
)
def
require_soundfile
(
test_case
):
...
...
@@ -630,80 +531,56 @@ def require_soundfile(test_case):
These tests are skipped when soundfile isn't installed.
"""
if
not
is_soundfile_availble
():
return
unittest
.
skip
(
"test requires soundfile"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_soundfile_availble
(),
"test requires soundfile"
)(
test_case
)
def
require_deepspeed
(
test_case
):
"""
Decorator marking a test that requires deepspeed
"""
if
not
is_deepspeed_available
():
return
unittest
.
skip
(
"test requires deepspeed"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_deepspeed_available
(),
"test requires deepspeed"
)(
test_case
)
def
require_fairscale
(
test_case
):
"""
Decorator marking a test that requires fairscale
"""
if
not
is_fairscale_available
():
return
unittest
.
skip
(
"test requires fairscale"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_fairscale_available
(),
"test requires fairscale"
)(
test_case
)
def
require_apex
(
test_case
):
"""
Decorator marking a test that requires apex
"""
if
not
is_apex_available
():
return
unittest
.
skip
(
"test requires apex"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_apex_available
(),
"test requires apex"
)(
test_case
)
def
require_bitsandbytes
(
test_case
):
"""
Decorator for bits and bytes (bnb) dependency
"""
if
not
is_bitsandbytes_available
():
return
unittest
.
skip
(
"test requires bnb"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_bitsandbytes_available
(),
"test requires bnb"
)(
test_case
)
def
require_phonemizer
(
test_case
):
"""
Decorator marking a test that requires phonemizer
"""
if
not
is_phonemizer_available
():
return
unittest
.
skip
(
"test requires phonemizer"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_phonemizer_available
(),
"test requires phonemizer"
)(
test_case
)
def
require_pyctcdecode
(
test_case
):
"""
Decorator marking a test that requires pyctcdecode
"""
if
not
is_pyctcdecode_available
():
return
unittest
.
skip
(
"test requires pyctcdecode"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_pyctcdecode_available
(),
"test requires pyctcdecode"
)(
test_case
)
def
require_librosa
(
test_case
):
"""
Decorator marking a test that requires librosa
"""
if
not
is_librosa_available
():
return
unittest
.
skip
(
"test requires librosa"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
is_librosa_available
(),
"test requires librosa"
)(
test_case
)
def
cmd_exists
(
cmd
):
...
...
@@ -714,10 +591,7 @@ def require_usr_bin_time(test_case):
"""
Decorator marking a test that requires `/usr/bin/time`
"""
if
not
cmd_exists
(
"/usr/bin/time"
):
return
unittest
.
skip
(
"test requires /usr/bin/time"
)(
test_case
)
else
:
return
test_case
return
unittest
.
skipUnless
(
cmd_exists
(
"/usr/bin/time"
),
"test requires /usr/bin/time"
)(
test_case
)
def
get_gpu_count
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment