Unverified Commit 871c31a6 authored by Yih-Dar, committed by GitHub

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to disable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipeline_model_mapping

* Fix import after adding pipeline_model_mapping

* Fix style and quality after adding pipeline_model_mapping

* Fix one more import after adding pipeline_model_mapping

* Fix style and quality after adding pipeline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipeline_model_mapping cannot be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
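
Every file in the diff below follows the same recipe: import `PipelineTesterMixin` from the new shared `test_pipeline_mixin` module, add it to the test class bases next to `ModelTesterMixin` / `TFModelTesterMixin`, and declare a `pipeline_model_mapping` that maps pipeline task names to the model classes the pipeline tests should exercise, guarded by `is_torch_available()` / `is_tf_available()`. A minimal sketch of the pattern, using BERT classes purely as stand-ins (the BERT test file itself is not part of the excerpt below, and `BertLikeModelTest` is a placeholder name):

# Sketch of the new pattern only -- "BertLikeModelTest" is a placeholder, not a class from this commit.
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    from transformers import BertForSequenceClassification, BertModel


@require_torch
class BertLikeModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (BertModel, BertForSequenceClassification) if is_torch_available() else ()
    # PipelineTesterMixin picks up this task -> model class mapping and runs the pipeline tests
    # that apply to the model; tasks not listed here are skipped ("skip tasks not in mapping" above).
    pipeline_model_mapping = (
        {
            "feature-extraction": BertModel,
            "text-classification": BertForSequenceClassification,
            "zero-shot": BertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )

The per-model diffs below apply exactly this shape, with the task set varying by model capabilities (for example, Realm gets an empty mapping and SEW maps audio tasks).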
@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -885,9 +886,20 @@ class ProphetNetStandaloneEncoderModelTester:
 @require_torch
-class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "conversational": ProphetNetForConditionalGeneration,
+            "feature-extraction": ProphetNetModel,
+            "summarization": ProphetNetForConditionalGeneration,
+            "text2text-generation": ProphetNetForConditionalGeneration,
+            "text-generation": ProphetNetForCausalLM,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_resize_embeddings = False
     is_encoder_decoder = True
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import require_pytorch_quantization, require_tor
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -419,7 +420,7 @@ class QDQBertModelTester:
 @require_torch
 @require_pytorch_quantization
-class QDQBertModelTest(ModelTesterMixin, unittest.TestCase):
+class QDQBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             QDQBertModel,
@@ -435,6 +436,19 @@ class QDQBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (QDQBertLMHeadModel,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": QDQBertModel,
+            "fill-mask": QDQBertForMaskedLM,
+            "question-answering": QDQBertForQuestionAnswering,
+            "text-classification": QDQBertForSequenceClassification,
+            "text-generation": QDQBertLMHeadModel,
+            "token-classification": QDQBertForTokenClassification,
+            "zero-shot": QDQBertForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     def setUp(self):
         self.model_tester = QDQBertModelTester(self)
...
@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -303,7 +304,7 @@ class RealmModelTester:
 @require_torch
-class RealmModelTest(ModelTesterMixin, unittest.TestCase):
+class RealmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RealmEmbedder,
@@ -316,6 +317,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = ()
+    pipeline_model_mapping = {} if is_torch_available() else {}
     # disable these tests because there is no base_model in Realm
     test_save_load_fast_init_from_base = False
...
@@ -28,6 +28,7 @@ from transformers.testing_utils import (
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -683,13 +684,27 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
 @require_torch
-class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class ReformerLSHAttnModelTest(
+    ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     all_model_classes = (
         (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
         if is_torch_available()
         else ()
     )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": ReformerModel,
+            "fill-mask": ReformerForMaskedLM,
+            "question-answering": ReformerForQuestionAnswering,
+            "text-classification": ReformerForSequenceClassification,
+            "text-generation": ReformerModelWithLMHead,
+            "zero-shot": ReformerForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
...
@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -118,13 +119,18 @@ class RegNetModelTester:
 @require_torch
-class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
+class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
     """
     all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_resize_embeddings = False
...
@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -111,13 +112,18 @@ class TFRegNetModelTester:
 @require_tf
-class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFRegNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
     """
     all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": TFRegNetModel, "image-classification": TFRegNetForImageClassification}
+        if is_tf_available()
+        else {}
+    )
     test_pruning = False
     test_onnx = False
...
@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -360,7 +361,7 @@ class RemBertModelTester:
 @require_torch
-class RemBertModelTest(ModelTesterMixin, unittest.TestCase):
+class RemBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RemBertModel,
@@ -375,6 +376,19 @@ class RemBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (RemBertForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": RemBertModel,
+            "fill-mask": RemBertForMaskedLM,
+            "question-answering": RemBertForQuestionAnswering,
+            "text-classification": RemBertForSequenceClassification,
+            "text-generation": RemBertForCausalLM,
+            "token-classification": RemBertForTokenClassification,
+            "zero-shot": RemBertForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     def setUp(self):
         self.model_tester = RemBertModelTester(self)
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -570,7 +571,7 @@ class TFRemBertModelTester:
 @require_tf
-class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFRemBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFRemBertModel,
@@ -584,6 +585,19 @@ class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFRemBertModel,
+            "fill-mask": TFRemBertForMaskedLM,
+            "question-answering": TFRemBertForQuestionAnswering,
+            "text-classification": TFRemBertForSequenceClassification,
+            "text-generation": TFRemBertForCausalLM,
+            "token-classification": TFRemBertForTokenClassification,
+            "zero-shot": TFRemBertForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -150,7 +151,7 @@ class ResNetModelTester:
 @require_torch
-class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
+class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
@@ -165,6 +166,11 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {"feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
     test_pruning = False
...
@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_avail
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -116,13 +117,18 @@ class TFResNetModelTester:
 @require_tf
-class TFResNetModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFResNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
     """
     all_model_classes = (TFResNetModel, TFResNetForImageClassification) if is_tf_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": TFResNetModel, "image-classification": TFResNetForImageClassification}
+        if is_tf_available()
+        else {}
+    )
     test_pruning = False
     test_resize_embeddings = False
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -366,7 +367,7 @@ class RobertaModelTester:
 @require_torch
-class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RobertaForCausalLM,
@@ -381,6 +382,19 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": RobertaModel,
+            "fill-mask": RobertaForMaskedLM,
+            "question-answering": RobertaForQuestionAnswering,
+            "text-classification": RobertaForSequenceClassification,
+            "text-generation": RobertaForCausalLM,
+            "token-classification": RobertaForTokenClassification,
+            "zero-shot": RobertaForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
     def setUp(self):
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -547,7 +548,7 @@ class TFRobertaModelTester:
 @require_tf
-class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFRobertaModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFRobertaModel,
@@ -560,6 +561,19 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFRobertaModel,
+            "fill-mask": TFRobertaForMaskedLM,
+            "question-answering": TFRobertaForQuestionAnswering,
+            "text-classification": TFRobertaForSequenceClassification,
+            "text-generation": TFRobertaForCausalLM,
+            "token-classification": TFRobertaForTokenClassification,
+            "zero-shot": TFRobertaForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -365,7 +366,7 @@ class RobertaPreLayerNormModelTester:
 @require_torch
 # Copied from tests.models.roberta.test_modelling_roberta.RobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
-class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RobertaPreLayerNormForCausalLM,
@@ -380,6 +381,19 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, unit
         else ()
     )
     all_generative_model_classes = (RobertaPreLayerNormForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": RobertaPreLayerNormModel,
+            "fill-mask": RobertaPreLayerNormForMaskedLM,
+            "question-answering": RobertaPreLayerNormForQuestionAnswering,
+            "text-classification": RobertaPreLayerNormForSequenceClassification,
+            "text-generation": RobertaPreLayerNormForCausalLM,
+            "token-classification": RobertaPreLayerNormForTokenClassification,
+            "zero-shot": RobertaPreLayerNormForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = False
     def setUp(self):
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -549,7 +550,7 @@ class TFRobertaPreLayerNormModelTester:
 @require_tf
 # Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
-class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFRobertaPreLayerNormModel,
@@ -562,6 +563,19 @@ class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFRobertaPreLayerNormModel,
+            "fill-mask": TFRobertaPreLayerNormForMaskedLM,
+            "question-answering": TFRobertaPreLayerNormForQuestionAnswering,
+            "text-classification": TFRobertaPreLayerNormForSequenceClassification,
+            "text-generation": TFRobertaPreLayerNormForCausalLM,
+            "token-classification": TFRobertaPreLayerNormForTokenClassification,
+            "zero-shot": TFRobertaPreLayerNormForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -555,7 +556,7 @@ class RoCBertModelTester:
 @require_torch
-class RoCBertModelTest(ModelTesterMixin, unittest.TestCase):
+class RoCBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RoCBertModel,
@@ -571,6 +572,19 @@ class RoCBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (RoCBertForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": RoCBertModel,
+            "fill-mask": RoCBertForMaskedLM,
+            "question-answering": RoCBertForQuestionAnswering,
+            "text-classification": RoCBertForSequenceClassification,
+            "text-generation": RoCBertForCausalLM,
+            "token-classification": RoCBertForTokenClassification,
+            "zero-shot": RoCBertForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
...
@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -360,7 +361,7 @@ class RoFormerModelTester:
 @require_torch
-class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
+class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             RoFormerModel,
@@ -375,6 +376,19 @@ class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (RoFormerForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": RoFormerModel,
+            "fill-mask": RoFormerForMaskedLM,
+            "question-answering": RoFormerForQuestionAnswering,
+            "text-classification": RoFormerForSequenceClassification,
+            "text-generation": RoFormerForCausalLM,
+            "token-classification": RoFormerForTokenClassification,
+            "zero-shot": RoFormerForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     def setUp(self):
         self.model_tester = RoFormerModelTester(self)
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -239,7 +240,7 @@ class TFRoFormerModelTester:
 @require_tf
-class TFRoFormerModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFRoFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFRoFormerModel,
@@ -253,6 +254,19 @@ class TFRoFormerModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFRoFormerModel,
+            "fill-mask": TFRoFormerForMaskedLM,
+            "question-answering": TFRoFormerForQuestionAnswering,
+            "text-classification": TFRoFormerForSequenceClassification,
+            "text-generation": TFRoFormerForCausalLM,
+            "token-classification": TFRoFormerForTokenClassification,
+            "zero-shot": TFRoFormerForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -159,7 +160,7 @@ class SegformerModelTester:
 @require_torch
-class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
+class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             SegformerModel,
@@ -169,6 +170,15 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": SegformerModel,
+            "image-classification": SegformerForImageClassification,
+            "image-segmentation": SegformerForSemanticSegmentation,
+        }
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
     test_head_masking = False
...
@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tf, slow
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
@@ -150,12 -151,17 @@ class TFSegformerModelTester:
 @require_tf
-class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFSegformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (TFSegformerModel, TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {"feature-extraction": TFSegformerModel, "image-classification": TFSegformerForImageClassification}
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -31,6 +31,7 @@ from ...test_modeling_common import (
     ids_tensor,
     random_attention_mask,
 )
+from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
@@ -299,8 +300,17 @@ class SEWModelTester:
 @require_torch
-class SEWModelTest(ModelTesterMixin, unittest.TestCase):
+class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "audio-classification": SEWForSequenceClassification,
+            "automatic-speech-recognition": SEWForCTC,
+            "feature-extraction": SEWModel,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_headmasking = False
...