Unverified Commit 871c31a6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to diable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipieline_model_mapping

* Fix import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix one more import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipieline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -173,7 +174,7 @@ class MobileViTModelTester: ...@@ -173,7 +174,7 @@ class MobileViTModelTester:
@require_torch @require_torch
class MobileViTModelTest(ModelTesterMixin, unittest.TestCase): class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
""" """
Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds, Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds,
attention_mask and seq_length. attention_mask and seq_length.
...@@ -184,6 +185,15 @@ class MobileViTModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -184,6 +185,15 @@ class MobileViTModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": MobileViTModel,
"image-classification": MobileViTForImageClassification,
"image-segmentation": MobileViTForSemanticSegmentation,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -154,7 +155,7 @@ class TFMobileViTModelTester: ...@@ -154,7 +155,7 @@ class TFMobileViTModelTester:
@require_tf @require_tf
class TFMobileViTModelTest(TFModelTesterMixin, unittest.TestCase): class TFMobileViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
""" """
Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds, Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds,
attention_mask and seq_length. attention_mask and seq_length.
...@@ -165,6 +166,11 @@ class TFMobileViTModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -165,6 +166,11 @@ class TFMobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": TFMobileViTModel, "image-classification": TFMobileViTForImageClassification}
if is_tf_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -190,7 +191,7 @@ class MPNetModelTester: ...@@ -190,7 +191,7 @@ class MPNetModelTester:
@require_torch @require_torch
class MPNetModelTest(ModelTesterMixin, unittest.TestCase): class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
MPNetForMaskedLM, MPNetForMaskedLM,
...@@ -203,6 +204,18 @@ class MPNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -203,6 +204,18 @@ class MPNetModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": MPNetModel,
"fill-mask": MPNetForMaskedLM,
"question-answering": MPNetForQuestionAnswering,
"text-classification": MPNetForSequenceClassification,
"token-classification": MPNetForTokenClassification,
"zero-shot": MPNetForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -184,7 +185,7 @@ class TFMPNetModelTester: ...@@ -184,7 +185,7 @@ class TFMPNetModelTester:
@require_tf @require_tf
class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): class TFMPNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFMPNetForMaskedLM, TFMPNetForMaskedLM,
...@@ -197,6 +198,18 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -197,6 +198,18 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFMPNetModel,
"fill-mask": TFMPNetForMaskedLM,
"question-answering": TFMPNetForQuestionAnswering,
"text-classification": TFMPNetForSequenceClassification,
"token-classification": TFMPNetForTokenClassification,
"zero-shot": TFMPNetForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -28,6 +28,7 @@ from transformers.utils import cached_property ...@@ -28,6 +28,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -405,13 +406,28 @@ class MvpHeadTests(unittest.TestCase): ...@@ -405,13 +406,28 @@ class MvpHeadTests(unittest.TestCase):
@require_torch @require_torch
class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(MvpModel, MvpForConditionalGeneration, MvpForSequenceClassification, MvpForQuestionAnswering) (MvpModel, MvpForConditionalGeneration, MvpForSequenceClassification, MvpForQuestionAnswering)
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = (MvpForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (MvpForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": MvpForConditionalGeneration,
"feature-extraction": MvpModel,
"fill-mask": MvpForConditionalGeneration,
"question-answering": MvpForQuestionAnswering,
"summarization": MvpForConditionalGeneration,
"text2text-generation": MvpForConditionalGeneration,
"text-classification": MvpForSequenceClassification,
"text-generation": MvpForCausalLM,
"zero-shot": MvpForSequenceClassification,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -189,7 +190,7 @@ class NatModelTester: ...@@ -189,7 +190,7 @@ class NatModelTester:
@require_natten @require_natten
@require_torch @require_torch
class NatModelTest(ModelTesterMixin, unittest.TestCase): class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
NatModel, NatModel,
...@@ -199,6 +200,11 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -199,6 +200,11 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{"feature-extraction": NatModel, "image-classification": NatForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_torchscript = False test_torchscript = False
......
...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t ...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -315,7 +316,7 @@ class NezhaModelTester: ...@@ -315,7 +316,7 @@ class NezhaModelTester:
@require_torch @require_torch
class NezhaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class NezhaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
NezhaModel, NezhaModel,
...@@ -330,6 +331,18 @@ class NezhaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) ...@@ -330,6 +331,18 @@ class NezhaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": NezhaModel,
"fill-mask": NezhaForMaskedLM,
"question-answering": NezhaForQuestionAnswering,
"text-classification": NezhaForSequenceClassification,
"token-classification": NezhaForTokenClassification,
"zero-shot": NezhaForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
# special case for ForPreTraining model # special case for ForPreTraining model
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -216,7 +217,7 @@ class NystromformerModelTester: ...@@ -216,7 +217,7 @@ class NystromformerModelTester:
@require_torch @require_torch
class NystromformerModelTest(ModelTesterMixin, unittest.TestCase): class NystromformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
NystromformerModel, NystromformerModel,
...@@ -229,6 +230,18 @@ class NystromformerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -229,6 +230,18 @@ class NystromformerModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": NystromformerModel,
"fill-mask": NystromformerForMaskedLM,
"question-answering": NystromformerForQuestionAnswering,
"text-classification": NystromformerForSequenceClassification,
"token-classification": NystromformerForTokenClassification,
"zero-shot": NystromformerForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
......
...@@ -27,6 +27,7 @@ from transformers.utils import cached_property ...@@ -27,6 +27,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -214,8 +215,9 @@ class OneFormerModelTester: ...@@ -214,8 +215,9 @@ class OneFormerModelTester:
@require_torch @require_torch
class OneFormerModelTest(ModelTesterMixin, unittest.TestCase): class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (OneFormerModel, OneFormerForUniversalSegmentation) if is_torch_available() else () all_model_classes = (OneFormerModel, OneFormerForUniversalSegmentation) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": OneFormerModel} if is_torch_available() else {}
is_encoder_decoder = False is_encoder_decoder = False
test_pruning = False test_pruning = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -189,7 +190,7 @@ class OpenAIGPTModelTester: ...@@ -189,7 +190,7 @@ class OpenAIGPTModelTester:
@require_torch @require_torch
class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification) (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
if is_torch_available() if is_torch_available()
...@@ -198,6 +199,16 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC ...@@ -198,6 +199,16 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
all_generative_model_classes = ( all_generative_model_classes = (
(OpenAIGPTLMHeadModel,) if is_torch_available() else () (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
pipeline_model_mapping = (
{
"feature-extraction": OpenAIGPTModel,
"text-classification": OpenAIGPTForSequenceClassification,
"text-generation": OpenAIGPTLMHeadModel,
"zero-shot": OpenAIGPTForSequenceClassification,
}
if is_torch_available()
else {}
)
# special case for DoubleHeads model # special case for DoubleHeads model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -191,7 +192,7 @@ class TFOpenAIGPTModelTester: ...@@ -191,7 +192,7 @@ class TFOpenAIGPTModelTester:
@require_tf @require_tf
class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): class TFOpenAIGPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel, TFOpenAIGPTForSequenceClassification) (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel, TFOpenAIGPTForSequenceClassification)
if is_tf_available() if is_tf_available()
...@@ -200,6 +201,16 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -200,6 +201,16 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else () (TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
pipeline_model_mapping = (
{
"feature-extraction": TFOpenAIGPTModel,
"text-classification": TFOpenAIGPTForSequenceClassification,
"text-generation": TFOpenAIGPTLMHeadModel,
"zero-shot": TFOpenAIGPTForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t ...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -183,13 +184,24 @@ class OPTModelTester: ...@@ -183,13 +184,24 @@ class OPTModelTester:
@require_torch @require_torch
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering) (OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering)
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else () all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": OPTModel,
"question-answering": OPTForQuestionAnswering,
"text-classification": OPTForSequenceClassification,
"text-generation": OPTForCausalLM,
"zero-shot": OPTForSequenceClassification,
}
if is_torch_available()
else {}
)
is_encoder_decoder = False is_encoder_decoder = False
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, slow, ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, slow,
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -146,9 +147,12 @@ class TFOPTModelTester: ...@@ -146,9 +147,12 @@ class TFOPTModelTester:
@require_tf @require_tf
class TFOPTModelTest(TFModelTesterMixin, unittest.TestCase): class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else () all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else ()
all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else () all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFOPTModel, "text-generation": TFOPTForCausalLM} if is_tf_available() else {}
)
is_encoder_decoder = False is_encoder_decoder = False
test_pruning = False test_pruning = False
test_onnx = False test_onnx = False
......
...@@ -35,6 +35,7 @@ from ...test_modeling_common import ( ...@@ -35,6 +35,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -393,8 +394,13 @@ class OwlViTModelTester: ...@@ -393,8 +394,13 @@ class OwlViTModelTester:
@require_torch @require_torch
class OwlViTModelTest(ModelTesterMixin, unittest.TestCase): class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (OwlViTModel,) if is_torch_available() else () all_model_classes = (OwlViTModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": OwlViTModel, "zero-shot-object-detection": OwlViTForObjectDetection}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.utils import cached_property ...@@ -24,6 +24,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
from ..mbart.test_modeling_mbart import AbstractSeq2SeqIntegrationTest from ..mbart.test_modeling_mbart import AbstractSeq2SeqIntegrationTest
...@@ -233,9 +234,20 @@ class PegasusModelTester: ...@@ -233,9 +234,20 @@ class PegasusModelTester:
@require_torch @require_torch
class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else () all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": PegasusForConditionalGeneration,
"feature-extraction": PegasusModel,
"summarization": PegasusForConditionalGeneration,
"text2text-generation": PegasusForConditionalGeneration,
"text-generation": PegasusForCausalLM,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True fx_compatible = True
test_resize_position_embeddings = True test_resize_position_embeddings = True
......
...@@ -22,6 +22,7 @@ from transformers.utils import cached_property ...@@ -22,6 +22,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -175,9 +176,19 @@ def prepare_pegasus_inputs_dict( ...@@ -175,9 +176,19 @@ def prepare_pegasus_inputs_dict(
@require_tf @require_tf
class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFPegasusForConditionalGeneration, TFPegasusModel) if is_tf_available() else () all_model_classes = (TFPegasusForConditionalGeneration, TFPegasusModel) if is_tf_available() else ()
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
pipeline_model_mapping = (
{
"conversational": TFPegasusForConditionalGeneration,
"feature-extraction": TFPegasusModel,
"summarization": TFPegasusForConditionalGeneration,
"text2text-generation": TFPegasusForConditionalGeneration,
}
if is_tf_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_onnx = False test_onnx = False
......
...@@ -27,6 +27,7 @@ from transformers.utils import cached_property ...@@ -27,6 +27,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -194,9 +195,19 @@ class PegasusXModelTester: ...@@ -194,9 +195,19 @@ class PegasusXModelTester:
@require_torch @require_torch
class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (PegasusXModel, PegasusXForConditionalGeneration) if is_torch_available() else () all_model_classes = (PegasusXModel, PegasusXForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusXForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (PegasusXForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": PegasusXForConditionalGeneration,
"feature-extraction": PegasusXModel,
"summarization": PegasusXForConditionalGeneration,
"text2text-generation": PegasusXForConditionalGeneration,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
......
...@@ -32,6 +32,7 @@ from transformers.utils import is_torch_available, is_vision_available ...@@ -32,6 +32,7 @@ from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -264,7 +265,7 @@ class PerceiverModelTester: ...@@ -264,7 +265,7 @@ class PerceiverModelTester:
@require_torch @require_torch
class PerceiverModelTest(ModelTesterMixin, unittest.TestCase): class PerceiverModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
PerceiverModel, PerceiverModel,
...@@ -279,6 +280,21 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -279,6 +280,21 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": PerceiverModel,
"fill-mask": PerceiverForMaskedLM,
"image-classification": (
PerceiverForImageClassificationConvProcessing,
PerceiverForImageClassificationFourier,
PerceiverForImageClassificationLearned,
),
"text-classification": PerceiverForSequenceClassification,
"zero-shot": PerceiverForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
test_torchscript = False test_torchscript = False
......
...@@ -26,6 +26,7 @@ from transformers.utils import cached_property ...@@ -26,6 +26,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -213,11 +214,24 @@ class PLBartModelTester: ...@@ -213,11 +214,24 @@ class PLBartModelTester:
@require_torch @require_torch
class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(PLBartModel, PLBartForConditionalGeneration, PLBartForSequenceClassification) if is_torch_available() else () (PLBartModel, PLBartForConditionalGeneration, PLBartForSequenceClassification) if is_torch_available() else ()
) )
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": PLBartForConditionalGeneration,
"feature-extraction": PLBartModel,
"summarization": PLBartForConditionalGeneration,
"text2text-generation": PLBartForConditionalGeneration,
"text-classification": PLBartForSequenceClassification,
"text-generation": PLBartForCausalLM,
"zero-shot": PLBartForSequenceClassification,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = False # Fix me Michael fx_compatible = False # Fix me Michael
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -121,8 +122,13 @@ class PoolFormerModelTester: ...@@ -121,8 +122,13 @@ class PoolFormerModelTester:
@require_torch @require_torch
class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else () all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
if is_torch_available()
else {}
)
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment