Unverified Commit 871c31a6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to diable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipieline_model_mapping

* Fix import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix one more import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipieline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
...@@ -31,6 +31,7 @@ from ...test_modeling_common import ( ...@@ -31,6 +31,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -320,8 +321,17 @@ class SEWDModelTester: ...@@ -320,8 +321,17 @@ class SEWDModelTester:
@require_torch @require_torch
class SEWDModelTest(ModelTesterMixin, unittest.TestCase): class SEWDModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (SEWDForCTC, SEWDModel, SEWDForSequenceClassification) if is_torch_available() else () all_model_classes = (SEWDForCTC, SEWDModel, SEWDForSequenceClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{
"audio-classification": SEWDForSequenceClassification,
"automatic-speech-recognition": SEWDForCTC,
"feature-extraction": SEWDModel,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
test_torchscript = False test_torchscript = False
......
...@@ -35,6 +35,7 @@ from transformers.utils import cached_property ...@@ -35,6 +35,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -267,9 +268,14 @@ class Speech2TextModelTester: ...@@ -267,9 +268,14 @@ class Speech2TextModelTester:
@require_torch @require_torch
class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else () all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{"automatic-speech-recognition": Speech2TextForConditionalGeneration, "feature-extraction": Speech2TextModel}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_tf_available ...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_tf_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -209,9 +210,10 @@ class TFSpeech2TextModelTester: ...@@ -209,9 +210,10 @@ class TFSpeech2TextModelTester:
@require_tf @require_tf
class TFSpeech2TextModelTest(TFModelTesterMixin, unittest.TestCase): class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFSpeech2TextModel, TFSpeech2TextForConditionalGeneration) if is_tf_available() else () all_model_classes = (TFSpeech2TextModel, TFSpeech2TextForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFSpeech2TextForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFSpeech2TextForConditionalGeneration,) if is_tf_available() else ()
pipeline_model_mapping = {"feature-extraction": TFSpeech2TextModel} if is_tf_available() else {}
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import is_torch_available, require_torch, torch_ ...@@ -22,6 +22,7 @@ from transformers.testing_utils import is_torch_available, require_torch, torch_
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -176,9 +177,12 @@ class Speech2Text2StandaloneDecoderModelTester: ...@@ -176,9 +177,12 @@ class Speech2Text2StandaloneDecoderModelTester:
@require_torch @require_torch
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class Speech2Text2StandaloneDecoderModelTest(
ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase
):
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else () all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else () all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = {"text-generation": Speech2Text2ForCausalLM} if is_torch_available() else {}
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -39,6 +39,7 @@ from ...test_modeling_common import ( ...@@ -39,6 +39,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -160,8 +161,13 @@ class SpeechT5ModelTester: ...@@ -160,8 +161,13 @@ class SpeechT5ModelTester:
@require_torch @require_torch
class SpeechT5ModelTest(ModelTesterMixin, unittest.TestCase): class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (SpeechT5Model,) if is_torch_available() else () all_model_classes = (SpeechT5Model,) if is_torch_available() else ()
pipeline_model_mapping = (
{"automatic-speech-recognition": SpeechT5ForSpeechToText, "feature-extraction": SpeechT5Model}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, s ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, s
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -207,7 +208,7 @@ class SplinterModelTester: ...@@ -207,7 +208,7 @@ class SplinterModelTester:
@require_torch @require_torch
class SplinterModelTest(ModelTesterMixin, unittest.TestCase): class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
SplinterModel, SplinterModel,
...@@ -217,6 +218,11 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -217,6 +218,11 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": SplinterModel, "question-answering": SplinterForQuestionAnswering}
if is_torch_available()
else {}
)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -214,7 +215,7 @@ class SqueezeBertModelTester(object): ...@@ -214,7 +215,7 @@ class SqueezeBertModelTester(object):
@require_torch @require_torch
class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): class SqueezeBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
SqueezeBertModel, SqueezeBertModel,
...@@ -227,6 +228,18 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -227,6 +228,18 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else None else None
) )
pipeline_model_mapping = (
{
"feature-extraction": SqueezeBertModel,
"fill-mask": SqueezeBertForMaskedLM,
"question-answering": SqueezeBertForQuestionAnswering,
"text-classification": SqueezeBertForSequenceClassification,
"token-classification": SqueezeBertForTokenClassification,
"zero-shot": SqueezeBertForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
test_head_masking = False test_head_masking = False
......
...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -218,7 +219,7 @@ class SwinModelTester: ...@@ -218,7 +219,7 @@ class SwinModelTester:
@require_torch @require_torch
class SwinModelTest(ModelTesterMixin, unittest.TestCase): class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
SwinModel, SwinModel,
...@@ -229,6 +230,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -229,6 +230,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail ...@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -176,7 +177,7 @@ class TFSwinModelTester: ...@@ -176,7 +177,7 @@ class TFSwinModelTester:
@require_tf @require_tf
class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase): class TFSwinModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFSwinModel, TFSwinModel,
...@@ -186,6 +187,11 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -186,6 +187,11 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": TFSwinModel, "image-classification": TFSwinForImageClassification}
if is_tf_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -22,6 +22,7 @@ from transformers.utils import is_torch_available, is_vision_available ...@@ -22,6 +22,7 @@ from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -155,8 +156,9 @@ class Swin2SRModelTester: ...@@ -155,8 +156,9 @@ class Swin2SRModelTester:
@require_torch @require_torch
class Swin2SRModelTest(ModelTesterMixin, unittest.TestCase): class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else () all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": Swin2SRModel} if is_torch_available() else {}
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -170,10 +171,15 @@ class Swinv2ModelTester: ...@@ -170,10 +171,15 @@ class Swinv2ModelTester:
@require_torch @require_torch
class Swinv2ModelTest(ModelTesterMixin, unittest.TestCase): class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else () (Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else ()
) )
pipeline_model_mapping = (
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tokenizers, require_torch, requir ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tokenizers, require_torch, requir
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -546,11 +547,21 @@ class SwitchTransformersModelTester: ...@@ -546,11 +547,21 @@ class SwitchTransformersModelTester:
@require_torch @require_torch
class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else ()
) )
all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": SwitchTransformersForConditionalGeneration,
"feature-extraction": SwitchTransformersModel,
"summarization": SwitchTransformersForConditionalGeneration,
"text2text-generation": SwitchTransformersForConditionalGeneration,
}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
......
...@@ -32,6 +32,7 @@ from transformers.utils import cached_property ...@@ -32,6 +32,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -518,9 +519,19 @@ class T5ModelTester: ...@@ -518,9 +519,19 @@ class T5ModelTester:
@require_torch @require_torch
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": T5ForConditionalGeneration,
"feature-extraction": T5Model,
"summarization": T5ForConditionalGeneration,
"text2text-generation": T5ForConditionalGeneration,
}
if is_torch_available()
else {}
)
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -21,6 +21,7 @@ from transformers.utils import cached_property ...@@ -21,6 +21,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -239,10 +240,20 @@ class TFT5ModelTester: ...@@ -239,10 +240,20 @@ class TFT5ModelTester:
@require_tf @require_tf
class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
is_encoder_decoder = True is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
pipeline_model_mapping = (
{
"conversational": TFT5ForConditionalGeneration,
"feature-extraction": TFT5Model,
"summarization": TFT5ForConditionalGeneration,
"text2text-generation": TFT5ForConditionalGeneration,
}
if is_tf_available()
else {}
)
test_onnx = False test_onnx = False
def setUp(self): def setUp(self):
......
...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_timm, require_vision, slow, torch ...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_timm, require_vision, slow, torch
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_timm_available(): if is_timm_available():
...@@ -175,7 +176,7 @@ class TableTransformerModelTester: ...@@ -175,7 +176,7 @@ class TableTransformerModelTester:
@require_timm @require_timm
class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TableTransformerModel, TableTransformerModel,
...@@ -184,6 +185,11 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes ...@@ -184,6 +185,11 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes
if is_timm_available() if is_timm_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
if is_timm_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_torchscript = False test_torchscript = False
test_pruning = False test_pruning = False
......
...@@ -37,6 +37,7 @@ from transformers.utils import cached_property ...@@ -37,6 +37,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -408,7 +409,7 @@ class TapasModelTester: ...@@ -408,7 +409,7 @@ class TapasModelTester:
@require_torch @require_torch
class TapasModelTest(ModelTesterMixin, unittest.TestCase): class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TapasModel, TapasModel,
...@@ -419,6 +420,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -419,6 +420,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else None else None
) )
pipeline_model_mapping = (
{
"feature-extraction": TapasModel,
"fill-mask": TapasForMaskedLM,
"table-question-answering": TapasForQuestionAnswering,
"text-classification": TapasForSequenceClassification,
"zero-shot": TapasForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
test_head_masking = False test_head_masking = False
......
...@@ -39,6 +39,7 @@ from transformers.utils import cached_property ...@@ -39,6 +39,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -418,7 +419,7 @@ class TFTapasModelTester: ...@@ -418,7 +419,7 @@ class TFTapasModelTester:
@require_tensorflow_probability @require_tensorflow_probability
@require_tf @require_tf
class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase): class TFTapasModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFTapasModel, TFTapasModel,
...@@ -429,6 +430,16 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -429,6 +430,16 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFTapasModel,
"fill-mask": TFTapasForMaskedLM,
"text-classification": TFTapasForSequenceClassification,
"zero-shot": TFTapasForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -25,6 +25,7 @@ from transformers.testing_utils import is_flaky, require_torch, slow, torch_devi ...@@ -25,6 +25,7 @@ from transformers.testing_utils import is_flaky, require_torch, slow, torch_devi
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
TOLERANCE = 1e-4 TOLERANCE = 1e-4
...@@ -172,11 +173,12 @@ class TimeSeriesTransformerModelTester: ...@@ -172,11 +173,12 @@ class TimeSeriesTransformerModelTester:
@require_torch @require_torch
class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase): class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TimeSeriesTransformerModel, TimeSeriesTransformerForPrediction) if is_torch_available() else () (TimeSeriesTransformerModel, TimeSeriesTransformerForPrediction) if is_torch_available() else ()
) )
all_generative_model_classes = (TimeSeriesTransformerForPrediction,) if is_torch_available() else () all_generative_model_classes = (TimeSeriesTransformerForPrediction,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": TimeSeriesTransformerModel} if is_torch_available() else {}
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
......
...@@ -29,6 +29,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -29,6 +29,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -152,13 +153,18 @@ class TimesformerModelTester: ...@@ -152,13 +153,18 @@ class TimesformerModelTester:
@require_torch @require_torch
class TimesformerModelTest(ModelTesterMixin, unittest.TestCase): class TimesformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
""" """
Here we also overwrite some of the tests of test_modeling_common.py, as TimeSformer does not use input_ids, inputs_embeds, Here we also overwrite some of the tests of test_modeling_common.py, as TimeSformer does not use input_ids, inputs_embeds,
attention_mask and seq_length. attention_mask and seq_length.
""" """
all_model_classes = (TimesformerModel, TimesformerForVideoClassification) if is_torch_available() else () all_model_classes = (TimesformerModel, TimesformerForVideoClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TimesformerModel, "video-classification": TimesformerForVideoClassification}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
......
...@@ -26,6 +26,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -26,6 +26,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, random_attention_mask from ...test_modeling_common import ModelTesterMixin, _config_zero_init, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -94,8 +95,9 @@ class TrajectoryTransformerModelTester: ...@@ -94,8 +95,9 @@ class TrajectoryTransformerModelTester:
@require_torch @require_torch
class TrajectoryTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class TrajectoryTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TrajectoryTransformerModel,) if is_torch_available() else () all_model_classes = (TrajectoryTransformerModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": TrajectoryTransformerModel} if is_torch_available() else {}
# Ignoring of a failing test from GenerationTesterMixin, as the model does not use inputs_ids # Ignoring of a failing test from GenerationTesterMixin, as the model does not use inputs_ids
test_generate_without_input_ids = False test_generate_without_input_ids = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment