Unverified Commit 871c31a6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to diable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipieline_model_mapping

* Fix import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix one more import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipieline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
......@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -885,9 +886,20 @@ class ProphetNetStandaloneEncoderModelTester:
@require_torch
class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": ProphetNetForConditionalGeneration,
"feature-extraction": ProphetNetModel,
"summarization": ProphetNetForConditionalGeneration,
"text2text-generation": ProphetNetForConditionalGeneration,
"text-generation": ProphetNetForCausalLM,
}
if is_torch_available()
else {}
)
test_pruning = False
test_resize_embeddings = False
is_encoder_decoder = True
......
......@@ -23,6 +23,7 @@ from transformers.testing_utils import require_pytorch_quantization, require_tor
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -419,7 +420,7 @@ class QDQBertModelTester:
@require_torch
@require_pytorch_quantization
class QDQBertModelTest(ModelTesterMixin, unittest.TestCase):
class QDQBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
QDQBertModel,
......@@ -435,6 +436,19 @@ class QDQBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (QDQBertLMHeadModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": QDQBertModel,
"fill-mask": QDQBertForMaskedLM,
"question-answering": QDQBertForQuestionAnswering,
"text-classification": QDQBertForSequenceClassification,
"text-generation": QDQBertLMHeadModel,
"token-classification": QDQBertForTokenClassification,
"zero-shot": QDQBertForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self):
self.model_tester = QDQBertModelTester(self)
......
......@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -303,7 +304,7 @@ class RealmModelTester:
@require_torch
class RealmModelTest(ModelTesterMixin, unittest.TestCase):
class RealmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RealmEmbedder,
......@@ -316,6 +317,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = ()
pipeline_model_mapping = {} if is_torch_available() else {}
# disable these tests because there is no base_model in Realm
test_save_load_fast_init_from_base = False
......
......@@ -28,6 +28,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -683,13 +684,27 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class ReformerLSHAttnModelTest(
ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase
):
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
if is_torch_available()
else ()
)
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ReformerModel,
"fill-mask": ReformerForMaskedLM,
"question-answering": ReformerForQuestionAnswering,
"text-classification": ReformerForSequenceClassification,
"text-generation": ReformerModelWithLMHead,
"zero-shot": ReformerForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False
test_headmasking = False
test_torchscript = False
......
......@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -118,13 +119,18 @@ class RegNetModelTester:
@require_torch
class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
if is_torch_available()
else {}
)
test_pruning = False
test_resize_embeddings = False
......
......@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -111,13 +112,18 @@ class TFRegNetModelTester:
@require_tf
class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase):
class TFRegNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFRegNetModel, "image-classification": TFRegNetForImageClassification}
if is_tf_available()
else {}
)
test_pruning = False
test_onnx = False
......
......@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -360,7 +361,7 @@ class RemBertModelTester:
@require_torch
class RemBertModelTest(ModelTesterMixin, unittest.TestCase):
class RemBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RemBertModel,
......@@ -375,6 +376,19 @@ class RemBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (RemBertForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RemBertModel,
"fill-mask": RemBertForMaskedLM,
"question-answering": RemBertForQuestionAnswering,
"text-classification": RemBertForSequenceClassification,
"text-generation": RemBertForCausalLM,
"token-classification": RemBertForTokenClassification,
"zero-shot": RemBertForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self):
self.model_tester = RemBertModelTester(self)
......
......@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -570,7 +571,7 @@ class TFRemBertModelTester:
@require_tf
class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase):
class TFRemBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFRemBertModel,
......@@ -584,6 +585,19 @@ class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": TFRemBertModel,
"fill-mask": TFRemBertForMaskedLM,
"question-answering": TFRemBertForQuestionAnswering,
"text-classification": TFRemBertForSequenceClassification,
"text-generation": TFRemBertForCausalLM,
"token-classification": TFRemBertForTokenClassification,
"zero-shot": TFRemBertForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False
test_onnx = False
......
......@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -150,7 +151,7 @@ class ResNetModelTester:
@require_torch
class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
......@@ -165,6 +166,11 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True
test_pruning = False
......
......@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -116,13 +117,18 @@ class TFResNetModelTester:
@require_tf
class TFResNetModelTest(TFModelTesterMixin, unittest.TestCase):
class TFResNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFResNetModel, TFResNetForImageClassification) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFResNetModel, "image-classification": TFResNetForImageClassification}
if is_tf_available()
else {}
)
test_pruning = False
test_resize_embeddings = False
......
......@@ -23,6 +23,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -366,7 +367,7 @@ class RobertaModelTester:
@require_torch
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RobertaForCausalLM,
......@@ -381,6 +382,19 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
else ()
)
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RobertaModel,
"fill-mask": RobertaForMaskedLM,
"question-answering": RobertaForQuestionAnswering,
"text-classification": RobertaForSequenceClassification,
"text-generation": RobertaForCausalLM,
"token-classification": RobertaForTokenClassification,
"zero-shot": RobertaForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True
def setUp(self):
......
......@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -547,7 +548,7 @@ class TFRobertaModelTester:
@require_tf
class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
class TFRobertaModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFRobertaModel,
......@@ -560,6 +561,19 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": TFRobertaModel,
"fill-mask": TFRobertaForMaskedLM,
"question-answering": TFRobertaForQuestionAnswering,
"text-classification": TFRobertaForSequenceClassification,
"text-generation": TFRobertaForCausalLM,
"token-classification": TFRobertaForTokenClassification,
"zero-shot": TFRobertaForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False
test_onnx = False
......
......@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -365,7 +366,7 @@ class RobertaPreLayerNormModelTester:
@require_torch
# Copied from tests.models.roberta.test_modelling_roberta.RobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RobertaPreLayerNormForCausalLM,
......@@ -380,6 +381,19 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, unit
else ()
)
all_generative_model_classes = (RobertaPreLayerNormForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RobertaPreLayerNormModel,
"fill-mask": RobertaPreLayerNormForMaskedLM,
"question-answering": RobertaPreLayerNormForQuestionAnswering,
"text-classification": RobertaPreLayerNormForSequenceClassification,
"text-generation": RobertaPreLayerNormForCausalLM,
"token-classification": RobertaPreLayerNormForTokenClassification,
"zero-shot": RobertaPreLayerNormForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = False
def setUp(self):
......
......@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -549,7 +550,7 @@ class TFRobertaPreLayerNormModelTester:
@require_tf
# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, unittest.TestCase):
class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFRobertaPreLayerNormModel,
......@@ -562,6 +563,19 @@ class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": TFRobertaPreLayerNormModel,
"fill-mask": TFRobertaPreLayerNormForMaskedLM,
"question-answering": TFRobertaPreLayerNormForQuestionAnswering,
"text-classification": TFRobertaPreLayerNormForSequenceClassification,
"text-generation": TFRobertaPreLayerNormForCausalLM,
"token-classification": TFRobertaPreLayerNormForTokenClassification,
"zero-shot": TFRobertaPreLayerNormForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False
test_onnx = False
......
......@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -555,7 +556,7 @@ class RoCBertModelTester:
@require_torch
class RoCBertModelTest(ModelTesterMixin, unittest.TestCase):
class RoCBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RoCBertModel,
......@@ -571,6 +572,19 @@ class RoCBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (RoCBertForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RoCBertModel,
"fill-mask": RoCBertForMaskedLM,
"question-answering": RoCBertForQuestionAnswering,
"text-classification": RoCBertForSequenceClassification,
"text-generation": RoCBertForCausalLM,
"token-classification": RoCBertForTokenClassification,
"zero-shot": RoCBertForSequenceClassification,
}
if is_torch_available()
else {}
)
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -360,7 +361,7 @@ class RoFormerModelTester:
@require_torch
class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
RoFormerModel,
......@@ -375,6 +376,19 @@ class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (RoFormerForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RoFormerModel,
"fill-mask": RoFormerForMaskedLM,
"question-answering": RoFormerForQuestionAnswering,
"text-classification": RoFormerForSequenceClassification,
"text-generation": RoFormerForCausalLM,
"token-classification": RoFormerForTokenClassification,
"zero-shot": RoFormerForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self):
self.model_tester = RoFormerModelTester(self)
......
......@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -239,7 +240,7 @@ class TFRoFormerModelTester:
@require_tf
class TFRoFormerModelTest(TFModelTesterMixin, unittest.TestCase):
class TFRoFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFRoFormerModel,
......@@ -253,6 +254,19 @@ class TFRoFormerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": TFRoFormerModel,
"fill-mask": TFRoFormerForMaskedLM,
"question-answering": TFRoFormerForQuestionAnswering,
"text-classification": TFRoFormerForSequenceClassification,
"text-generation": TFRoFormerForCausalLM,
"token-classification": TFRoFormerForTokenClassification,
"zero-shot": TFRoFormerForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False
test_onnx = False
......
......@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -159,7 +160,7 @@ class SegformerModelTester:
@require_torch
class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
SegformerModel,
......@@ -169,6 +170,15 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": SegformerModel,
"image-classification": SegformerForImageClassification,
"image-segmentation": SegformerForSemanticSegmentation,
}
if is_torch_available()
else {}
)
fx_compatible = True
test_head_masking = False
......
......@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
......@@ -150,12 +151,17 @@ class TFSegformerModelTester:
@require_tf
class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
class TFSegformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(TFSegformerModel, TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": TFSegformerModel, "image-classification": TFSegformerForImageClassification}
if is_tf_available()
else {}
)
test_head_masking = False
test_onnx = False
......
......@@ -31,6 +31,7 @@ from ...test_modeling_common import (
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -299,8 +300,17 @@ class SEWModelTester:
@require_torch
class SEWModelTest(ModelTesterMixin, unittest.TestCase):
class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{
"audio-classification": SEWForSequenceClassification,
"automatic-speech-recognition": SEWForCTC,
"feature-extraction": SEWModel,
}
if is_torch_available()
else {}
)
test_pruning = False
test_headmasking = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment