Unverified Commit 1de7dc74, authored by amyeroberts and committed by GitHub

Skip tests properly (#31308)

* Skip tests properly

* [test_all]

* Add 'reason' as kwarg for skipTest

* [test_all] Fix up

* [test_all]
parent 1f9f57ab
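
The diff below applies one pattern across the test suite: skip messages are passed through the `reason` keyword of `unittest.skip` / `TestCase.skipTest`, and bare `return` statements inside tests become explicit skips so the test is reported as skipped instead of silently passing. A small illustrative sketch of the before/after (the test names and the condition are placeholders, not taken from the diff):

```python
import unittest


class ExampleTest(unittest.TestCase):
    # Before: positional message, and a bare `return` that makes the test
    # look like it passed:
    #
    #   @unittest.skip("this bug needs to be fixed")
    #   def test_feature_x(self): ...
    #
    #   def test_feature_y(self):
    #       if not condition:
    #           return  # counted as a pass, not a skip

    # After: the message goes through the `reason` keyword, and conditional
    # early exits become explicit skips that show up in the test report.
    @unittest.skip(reason="this bug needs to be fixed")
    def test_feature_x(self):
        ...

    def test_feature_y(self):
        condition = False  # illustrative placeholder
        if not condition:
            self.skipTest(reason="condition is not met")
        ...


if __name__ == "__main__":
    unittest.main()
```
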
@@ -862,7 +862,7 @@ Code, der fehlerhaft ist, einen schlechten Zustand verursacht, der sich auf ande
 - Hier sehen Sie, wie Sie einen ganzen Test bedingungslos überspringen können:
 ```python no-style
-@unittest.skip("this bug needs to be fixed")
+@unittest.skip(reason="this bug needs to be fixed")
 def test_feature_x():
 ```
...
@@ -881,7 +881,7 @@ code that's buggy causes some bad state that will affect other tests, do not use
 - Here is how to skip whole test unconditionally:
 ```python no-style
-@unittest.skip("this bug needs to be fixed")
+@unittest.skip(reason="this bug needs to be fixed")
 def test_feature_x():
 ```
...
@@ -809,7 +809,7 @@ with ExtendSysPath(f"{bindir}/.."):
 ```python no-style
-@unittest.skip("this bug needs to be fixed")
+@unittest.skip(reason="this bug needs to be fixed")
 def test_feature_x():
 ```
@@ -1211,4 +1211,3 @@ cmd_that_may_fail || true
 - [Github Actions:](https://github.com/actions/toolkit/issues/399)
 - [CircleCI:](https://ideas.circleci.com/ideas/CCI-I-344)
@@ -847,7 +847,7 @@ with ExtendSysPath(f"{bindir}/.."):
 - 전체 테스트를 무조건 건너뛰려면 다음과 같이 할 수 있습니다:
 ```python no-style
-@unittest.skip("this bug needs to be fixed")
+@unittest.skip(reason="this bug needs to be fixed")
 def test_feature_x():
 ```
...
@@ -226,7 +226,7 @@ def is_pt_tf_cross_test(test_case):
     """
     if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
-        return unittest.skip("test is PT+TF test")(test_case)
+        return unittest.skip(reason="test is PT+TF test")(test_case)
     else:
         try:
             import pytest  # We don't need a hard dependency on pytest in the main library
@@ -245,7 +245,7 @@ def is_pt_flax_cross_test(test_case):
     """
     if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
-        return unittest.skip("test is PT+FLAX test")(test_case)
+        return unittest.skip(reason="test is PT+FLAX test")(test_case)
     else:
         try:
             import pytest  # We don't need a hard dependency on pytest in the main library
@@ -262,7 +262,7 @@ def is_staging_test(test_case):
     Those tests will run using the staging environment of huggingface.co instead of the real model hub.
     """
     if not _run_staging:
-        return unittest.skip("test is staging test")(test_case)
+        return unittest.skip(reason="test is staging test")(test_case)
     else:
         try:
             import pytest  # We don't need a hard dependency on pytest in the main library
@@ -278,7 +278,7 @@ def is_pipeline_test(test_case):
     skipped.
     """
     if not _run_pipeline_tests:
-        return unittest.skip("test is pipeline test")(test_case)
+        return unittest.skip(reason="test is pipeline test")(test_case)
     else:
         try:
             import pytest  # We don't need a hard dependency on pytest in the main library
@@ -293,7 +293,7 @@ def is_agent_test(test_case):
     Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
     """
     if not _run_agent_tests:
-        return unittest.skip("test is an agent test")(test_case)
+        return unittest.skip(reason="test is an agent test")(test_case)
     else:
         try:
             import pytest  # We don't need a hard dependency on pytest in the main library
@@ -321,7 +321,7 @@ def tooslow(test_case):
     these will not be tested by the CI.
     """
-    return unittest.skip("test is too slow")(test_case)
+    return unittest.skip(reason="test is too slow")(test_case)

 def custom_tokenizers(test_case):
...@@ -709,7 +709,7 @@ def require_torch_multi_gpu(test_case): ...@@ -709,7 +709,7 @@ def require_torch_multi_gpu(test_case):
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu" To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
""" """
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return unittest.skip(reason="test requires PyTorch")(test_case)
import torch import torch
...@@ -723,7 +723,7 @@ def require_torch_multi_accelerator(test_case): ...@@ -723,7 +723,7 @@ def require_torch_multi_accelerator(test_case):
multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator" multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
""" """
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return unittest.skip(reason="test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")( return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
test_case test_case
@@ -735,7 +735,7 @@ def require_torch_non_multi_gpu(test_case):
     Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return unittest.skip(reason="test requires PyTorch")(test_case)

     import torch
@@ -747,7 +747,7 @@ def require_torch_non_multi_accelerator(test_case):
     Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return unittest.skip(reason="test requires PyTorch")(test_case)

     return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
@@ -757,7 +757,7 @@ def require_torch_up_to_2_gpus(test_case):
     Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return unittest.skip(reason="test requires PyTorch")(test_case)

     import torch
@@ -769,7 +769,7 @@ def require_torch_up_to_2_accelerators(test_case):
     Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return unittest.skip(reason="test requires PyTorch")(test_case)

     return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")
     (test_case)
@@ -806,7 +806,7 @@ def require_torch_multi_npu(test_case):
     To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
     """
     if not is_torch_npu_available():
-        return unittest.skip("test requires PyTorch NPU")(test_case)
+        return unittest.skip(reason="test requires PyTorch NPU")(test_case)

     return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
@@ -830,7 +830,7 @@ def require_torch_multi_xpu(test_case):
     To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
     """
     if not is_torch_xpu_available():
-        return unittest.skip("test requires PyTorch XPU")(test_case)
+        return unittest.skip(reason="test requires PyTorch XPU")(test_case)

     return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
@@ -1078,7 +1078,7 @@ def require_bitsandbytes(test_case):
         except ImportError:
             return test_case
     else:
-        return unittest.skip("test requires bitsandbytes and torch")(test_case)
+        return unittest.skip(reason="test requires bitsandbytes and torch")(test_case)

 def require_optimum(test_case):
...
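
The `require_*` and `is_*_test` helpers above all share one shape: when a precondition is missing they wrap the test in `unittest.skip(reason=...)`, otherwise they hand the test back (possibly wrapped in `unittest.skipUnless`). A minimal sketch of that shape, using a hypothetical `is_foo_available()` check and `require_foo` decorator rather than the real transformers helpers:

```python
import unittest


def is_foo_available() -> bool:
    # Stand-in for an availability check such as is_torch_available();
    # hardcoded here so the sketch is self-contained.
    return False


def require_foo(test_case):
    """Skip `test_case` unless the optional `foo` dependency is installed."""
    if not is_foo_available():
        # Passing the message via the `reason` keyword is the convention
        # this commit standardizes on.
        return unittest.skip(reason="test requires foo")(test_case)
    return test_case


@require_foo
class FooDependentTests(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Because the decorator returns either the untouched test case or a skipped one, it works on both individual test methods and whole `TestCase` classes, and the skip reason shows up directly in the test report.
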
@@ -108,13 +108,13 @@ def require_deepspeed_aio(test_case):
     Decorator marking a test that requires deepspeed aio (nvme)
     """
     if not is_deepspeed_available():
-        return unittest.skip("test requires deepspeed")(test_case)
+        return unittest.skip(reason="test requires deepspeed")(test_case)

     import deepspeed
     from deepspeed.ops.aio import AsyncIOBuilder

     if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
-        return unittest.skip("test requires deepspeed async-io")(test_case)
+        return unittest.skip(reason="test requires deepspeed async-io")(test_case)
     else:
         return test_case
@@ -643,7 +643,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         # print(trainer.model.b.item())
         # need to investigate at some point
         if (stage == ZERO3 and dtype == FP16) or (dtype == BF16):
-            return
+            self.skipTest(reason="When using zero3/fp16 or any/bf16 the optimizer seems run oddly")

         # it's enough that train didn't fail for this test, but we must check that
         # optimizer/scheduler didn't run (since if it did this test isn't testing the right thing)
@@ -795,7 +795,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         # ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
         # also has same losses for few steps but then slowly diverges. Need to figure it out.
         if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
-            return
+            self.skipTest(reason="hf_optim + hf_scheduler resumes with the correct states but slowly diverges")

         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
         ds_config_dict = self.get_config_dict(stage)
...@@ -1113,7 +1113,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): ...@@ -1113,7 +1113,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
@require_torch_multi_accelerator @require_torch_multi_accelerator
def test_inference(self, dtype): def test_inference(self, dtype):
if dtype == "bf16" and not is_torch_bf16_available_on_device(torch_device): if dtype == "bf16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest("test requires bfloat16 hardware support") self.skipTest(reason="test requires bfloat16 hardware support")
# this is just inference, so no optimizer should be loaded # this is just inference, so no optimizer should be loaded
# it only works for z3 (makes no sense with z1-z2) # it only works for z3 (makes no sense with z1-z2)
......
@@ -80,7 +80,7 @@ class TestTrainerExt(TestCasePlus):
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

         if not do_eval:
-            return
+            self.skipTest(reason="do_eval is False")

         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
...
@@ -463,9 +463,9 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask = self._get_input_ids_and_config()

             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
-                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
+                self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")

             config.use_cache = True
             config.is_decoder = True
@@ -625,9 +625,9 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask = self._get_input_ids_and_config()

             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
-                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
+                self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")

             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
@@ -667,7 +667,7 @@ class GenerationTesterMixin:
     def test_model_parallel_beam_search(self):
         for model_class in self.all_generative_model_classes:
             if "xpu" in torch_device:
-                return unittest.skip("device_map='auto' does not work with XPU devices")
+                return unittest.skip(reason="device_map='auto' does not work with XPU devices")

             if model_class._no_split_modules is None:
                 continue
@@ -765,7 +765,7 @@ class GenerationTesterMixin:
         # if no bos token id => cannot generate from None
         if config.bos_token_id is None:
-            return
+            self.skipTest(reason="bos_token_id is None")

         # hack in case they are equal, otherwise the attn mask will be [0]
         if config.bos_token_id == config.pad_token_id:
@@ -982,17 +982,17 @@ class GenerationTesterMixin:
     def test_contrastive_generate(self):
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support contrastive search generation")
+                self.skipTest(reason="Stateful models don't support contrastive search generation")

             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             config, input_ids, attention_mask = self._get_input_ids_and_config()

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1009,17 +1009,17 @@ class GenerationTesterMixin:
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support contrastive search generation")
+                self.skipTest(reason="Stateful models don't support contrastive search generation")

             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             config, input_ids, attention_mask = self._get_input_ids_and_config()

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1045,18 +1045,18 @@ class GenerationTesterMixin:
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support contrastive search generation")
+                self.skipTest(reason="Stateful models don't support contrastive search generation")

             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
-                self.skipTest("TODO: fix me")
+                self.skipTest(reason="TODO: fix me")

             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1087,9 +1087,9 @@ class GenerationTesterMixin:
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("May fix in the future: need custom cache handling")
+                self.skipTest(reason="May fix in the future: need custom cache handling")

             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
@@ -1102,7 +1102,7 @@ class GenerationTesterMixin:
                     "jamba",
                 ]
             ):
-                self.skipTest("May fix in the future: need model-specific fixes")
+                self.skipTest(reason="May fix in the future: need model-specific fixes")

             config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)

             # batch_size=1 is ok, but batch_size>1 will cause non-identical output
@@ -1135,9 +1135,9 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support assisted generation")
+                self.skipTest(reason="Stateful models don't support assisted generation")

             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
@@ -1151,14 +1151,14 @@ class GenerationTesterMixin:
                     "clvp",
                 ]
             ):
-                self.skipTest("May fix in the future: need model-specific fixes")
+                self.skipTest(reason="May fix in the future: need model-specific fixes")

             # enable cache
             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1206,9 +1206,9 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support assisted generation")
+                self.skipTest(reason="Stateful models don't support assisted generation")

             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
@@ -1222,14 +1222,14 @@ class GenerationTesterMixin:
                     "clvp",
                 ]
             ):
-                self.skipTest("May fix in the future: need model-specific fixes")
+                self.skipTest(reason="May fix in the future: need model-specific fixes")

             # enable cache
             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1268,9 +1268,9 @@ class GenerationTesterMixin:
         # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
-                self.skipTest("Stateful models don't support assisted generation")
+                self.skipTest(reason="Stateful models don't support assisted generation")

             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                self.skipTest("Won't fix: old model with different cache format")
+                self.skipTest(reason="Won't fix: old model with different cache format")

             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
@@ -1284,14 +1284,14 @@ class GenerationTesterMixin:
                     "clvp",
                 ]
             ):
-                self.skipTest("May fix in the future: need model-specific fixes")
+                self.skipTest(reason="May fix in the future: need model-specific fixes")

             # enable cache
             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1436,7 +1436,7 @@ class GenerationTesterMixin:
             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")

             model = model_class(config).to(torch_device)
             if "use_cache" not in inputs:
@@ -1445,7 +1445,7 @@ class GenerationTesterMixin:
             # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
-                self.skipTest("This model doesn't return `past_key_values`")
+                self.skipTest(reason="This model doesn't return `past_key_values`")

             num_hidden_layers = (
                 getattr(config, "decoder_layers", None)
@@ -1553,14 +1553,14 @@ class GenerationTesterMixin:
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
         for model_class in self.all_generative_model_classes:
             if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
-                self.skipTest("Won't fix: old model with unique inputs/caches/other")
+                self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
             if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
-                self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
+                self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

             if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
+                self.skipTest(reason="This model doesn't support caching")

             # Let's make it always:
             # 1. use cache (for obvious reasons)
@@ -1582,7 +1582,7 @@ class GenerationTesterMixin:
             # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
-                self.skipTest("This model doesn't return `past_key_values`")
+                self.skipTest(reason="This model doesn't return `past_key_values`")

             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
             outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
@@ -1632,7 +1632,7 @@ class GenerationTesterMixin:
         # 👉 tests with and without sampling so we can cover the most common use cases.
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_cache_class:
-                self.skipTest("This model does not support the new cache format")
+                self.skipTest(reason="This model does not support the new cache format")

             config, input_ids, attention_mask = self._get_input_ids_and_config()
             config.use_cache = True
@@ -1689,7 +1689,7 @@ class GenerationTesterMixin:
     def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_quantized_cache:
-                self.skipTest("This model does not support the quantized cache format")
+                self.skipTest(reason="This model does not support the quantized cache format")

             config, input_ids, attention_mask = self._get_input_ids_and_config()
             config.use_cache = True
...
@@ -67,7 +67,7 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
-            return
+            self.skipTest(reason="test_rust_tokenizer is set to False")

         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()
...
@@ -23,7 +23,6 @@ import requests
 from transformers import AlignConfig, AlignProcessor, AlignTextConfig, AlignVisionConfig
 from transformers.testing_utils import (
-    is_flax_available,
     require_torch,
     require_vision,
     slow,
@@ -56,10 +55,6 @@ if is_vision_available():
     from PIL import Image

-if is_flax_available():
-    pass

 class AlignVisionModelTester:
     def __init__(
         self,
@@ -215,9 +210,11 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
         check_hidden_states_output(inputs_dict, config, model_class)

+    @unittest.skip
     def test_training(self):
         pass

+    @unittest.skip
     def test_training_gradient_checkpointing(self):
         pass
@@ -355,9 +352,11 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    @unittest.skip
     def test_training(self):
         pass

+    @unittest.skip
     def test_training_gradient_checkpointing(self):
         pass
@@ -518,7 +517,7 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
-            return
+            self.skipTest(reason="test_torchscript is set to False")

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.torchscript = True
...
@@ -178,9 +178,11 @@ class AltCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    @unittest.skip
     def test_training(self):
         pass

+    @unittest.skip
     def test_training_gradient_checkpointing(self):
         pass
@@ -309,7 +311,7 @@ class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
     test_head_masking = False

     # TODO (@SunMarc): Fix me
-    @unittest.skip("It's broken.")
+    @unittest.skip(reason="It's broken.")
     def test_resize_tokens_embeddings(self):
         super().test_resize_tokens_embeddings()
@@ -324,9 +326,11 @@ class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    @unittest.skip
     def test_training(self):
         pass

+    @unittest.skip
     def test_training_gradient_checkpointing(self):
         pass
@@ -487,7 +491,7 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
-            return
+            self.skipTest(reason="test_torchscript is set to False")

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.torchscript = True
...
@@ -754,7 +754,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             model(**inputs)[0]

-    @unittest.skip("FineModel relies on codebook idx and does not return same logits")
+    @unittest.skip(reason="FineModel relies on codebook idx and does not return same logits")
     def test_inputs_embeds_matches_input_ids(self):
         pass
@@ -826,7 +826,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
         # resizing tokens_embeddings of a ModuleList
         original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
-            return
+            self.skipTest(reason="test_resize_embeddings is False")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
@@ -877,7 +877,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
         # resizing tokens_embeddings of a ModuleList
         original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
-            return
+            self.skipTest(reason="test_resize_embeddings is False")

         original_config.tie_word_embeddings = False
@@ -931,7 +931,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(reason="Model does not support flash_attention_2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -988,7 +988,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(reason="Model does not support flash_attention_2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
...
@@ -1515,9 +1515,10 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)

+    @unittest.skip(reason="Decoder cannot keep gradients")
     def test_retain_grad_hidden_states_attentions(self):
-        # decoder cannot keep gradients
         return

+    @unittest.skip
     def test_save_load_fast_init_from_base(self):
         pass
@@ -147,6 +147,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
         self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
         self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())

+    @unittest.skip
     def test_pretokenized_inputs(self):
         pass
...
@@ -75,7 +75,7 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
-            return
+            self.skipTest(reason="test_rust_tokenizer is set to False")

         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()
...
@@ -301,7 +301,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_training(self):
         if not self.model_tester.is_training:
-            return
+            self.skipTest(reason="model_tester.is_training is set to False")

         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
@@ -325,7 +325,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.model_tester.is_training:
-            return
+            self.skipTest(reason="model_tester.is_training is set to False")

         config.use_cache = False
         config.return_dict = True
...
@@ -614,7 +614,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         for model_class in self.all_model_classes:
             # BertForMultipleChoice behaves incorrectly in JIT environments.
             if model_class == BertForMultipleChoice:
-                return
+                self.skipTest(reason="BertForMultipleChoice behaves incorrectly in JIT environments.")

             config.torchscript = True
             model = model_class(config=config)
...
@@ -79,7 +79,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
-            return
+            self.skipTest(reason="test_rust_tokenizer is set to False")

         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()
...
@@ -716,7 +716,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
         """
         if not self.test_attention_probs:
-            return
+            self.skip("test_attention_probs is set to False")

         model = BigBirdModel.from_pretrained(
             "google/bigbird-roberta-base", attention_type="block_sparse", num_random_blocks=3, block_size=16
...
@@ -63,7 +63,7 @@ class BigBirdTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
-            return
+            self.skipTest(reason="test_rust_tokenizer is set to False")

         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()
...