"docs/vscode:/vscode.git/clone" did not exist on "7ca46335553609e4852dcb018c73cd5215e6e25a"
Unverified Commit 871c31a6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to diable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipieline_model_mapping

* Fix import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix one more import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipieline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -374,7 +375,7 @@ class ElectraModelTester: ...@@ -374,7 +375,7 @@ class ElectraModelTester:
@require_torch @require_torch
class ElectraModelTest(ModelTesterMixin, unittest.TestCase): class ElectraModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
ElectraModel, ElectraModel,
...@@ -389,6 +390,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -389,6 +390,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": ElectraModel,
"fill-mask": ElectraForMaskedLM,
"question-answering": ElectraForQuestionAnswering,
"text-classification": ElectraForSequenceClassification,
"text-generation": ElectraForCausalLM,
"token-classification": ElectraForTokenClassification,
"zero-shot": ElectraForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
# special case for ForPreTraining model # special case for ForPreTraining model
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -487,7 +488,7 @@ class TFElectraModelTester: ...@@ -487,7 +488,7 @@ class TFElectraModelTester:
@require_tf @require_tf
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): class TFElectraModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFElectraModel, TFElectraModel,
...@@ -501,6 +502,18 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -501,6 +502,18 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFElectraModel,
"fill-mask": TFElectraForMaskedLM,
"question-answering": TFElectraForQuestionAnswering,
"text-classification": TFElectraForSequenceClassification,
"token-classification": TFElectraForTokenClassification,
"zero-shot": TFElectraForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t ...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -426,7 +427,7 @@ class ErnieModelTester: ...@@ -426,7 +427,7 @@ class ErnieModelTester:
@require_torch @require_torch
class ErnieModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class ErnieModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
ErnieModel, ErnieModel,
...@@ -443,6 +444,19 @@ class ErnieModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) ...@@ -443,6 +444,19 @@ class ErnieModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
else () else ()
) )
all_generative_model_classes = (ErnieForCausalLM,) if is_torch_available() else () all_generative_model_classes = (ErnieForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ErnieModel,
"fill-mask": ErnieForMaskedLM,
"question-answering": ErnieForQuestionAnswering,
"text-classification": ErnieForSequenceClassification,
"text-generation": ErnieForCausalLM,
"token-classification": ErnieForTokenClassification,
"zero-shot": ErnieForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
# special case for ForPreTraining model # special case for ForPreTraining model
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -223,7 +224,7 @@ class ErnieMModelTester: ...@@ -223,7 +224,7 @@ class ErnieMModelTester:
@require_torch @require_torch
class ErnieMModelTest(ModelTesterMixin, unittest.TestCase): class ErnieMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
ErnieMModel, ErnieMModel,
...@@ -236,6 +237,17 @@ class ErnieMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -236,6 +237,17 @@ class ErnieMModelTest(ModelTesterMixin, unittest.TestCase):
else () else ()
) )
all_generative_model_classes = () all_generative_model_classes = ()
pipeline_model_mapping = (
{
"feature-extraction": ErnieMModel,
"question-answering": ErnieMForQuestionAnswering,
"text-classification": ErnieMForSequenceClassification,
"token-classification": ErnieMForTokenClassification,
"zero-shot": ErnieMForSequenceClassification,
}
if is_torch_available()
else {}
)
test_torchscript = False test_torchscript = False
def setUp(self): def setUp(self):
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_ ...@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -165,7 +166,7 @@ class EsmModelTester: ...@@ -165,7 +166,7 @@ class EsmModelTester:
@require_torch @require_torch
class EsmModelTest(ModelTesterMixin, unittest.TestCase): class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_mismatched_shapes = False test_mismatched_shapes = False
all_model_classes = ( all_model_classes = (
...@@ -179,6 +180,17 @@ class EsmModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -179,6 +180,17 @@ class EsmModelTest(ModelTesterMixin, unittest.TestCase):
else () else ()
) )
all_generative_model_classes = () all_generative_model_classes = ()
pipeline_model_mapping = (
{
"feature-extraction": EsmModel,
"fill-mask": EsmForMaskedLM,
"text-classification": EsmForSequenceClassification,
"token-classification": EsmForTokenClassification,
"zero-shot": EsmForSequenceClassification,
}
if is_torch_available()
else {}
)
test_sequence_classification_problem_types = True test_sequence_classification_problem_types = True
def setUp(self): def setUp(self):
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_ ...@@ -22,6 +22,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -143,11 +144,12 @@ class EsmFoldModelTester: ...@@ -143,11 +144,12 @@ class EsmFoldModelTester:
@require_torch @require_torch
class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase): class EsmFoldModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_mismatched_shapes = False test_mismatched_shapes = False
all_model_classes = (EsmForProteinFolding,) if is_torch_available() else () all_model_classes = (EsmForProteinFolding,) if is_torch_available() else ()
all_generative_model_classes = () all_generative_model_classes = ()
pipeline_model_mapping = {} if is_torch_available() else {}
test_sequence_classification_problem_types = False test_sequence_classification_problem_types = False
def setUp(self): def setUp(self):
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -194,7 +195,7 @@ class TFEsmModelTester: ...@@ -194,7 +195,7 @@ class TFEsmModelTester:
@require_tf @require_tf
class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase): class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFEsmModel, TFEsmModel,
...@@ -205,6 +206,17 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -205,6 +206,17 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFEsmModel,
"fill-mask": TFEsmForMaskedLM,
"text-classification": TFEsmForSequenceClassification,
"token-classification": TFEsmForTokenClassification,
"zero-shot": TFEsmForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, t
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -362,7 +363,7 @@ class FlaubertModelTester(object): ...@@ -362,7 +363,7 @@ class FlaubertModelTester(object):
@require_torch @require_torch
class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): class FlaubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
FlaubertModel, FlaubertModel,
...@@ -376,6 +377,18 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -376,6 +377,18 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": FlaubertModel,
"fill-mask": FlaubertWithLMHeadModel,
"question-answering": FlaubertForQuestionAnsweringSimple,
"text-classification": FlaubertForSequenceClassification,
"token-classification": FlaubertForTokenClassification,
"zero-shot": FlaubertForSequenceClassification,
}
if is_torch_available()
else {}
)
# Flaubert has 2 QA models -> need to manually set the correct labels for one of them here # Flaubert has 2 QA models -> need to manually set the correct labels for one of them here
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir ...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -274,7 +275,7 @@ class TFFlaubertModelTester: ...@@ -274,7 +275,7 @@ class TFFlaubertModelTester:
@require_tf @require_tf
class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): class TFFlaubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFFlaubertModel, TFFlaubertModel,
...@@ -290,6 +291,18 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -290,6 +291,18 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFFlaubertWithLMHeadModel,) if is_tf_available() else () (TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable ) # TODO (PVP): Check other models whether language generation is also applicable
pipeline_model_mapping = (
{
"feature-extraction": TFFlaubertModel,
"fill-mask": TFFlaubertWithLMHeadModel,
"question-answering": TFFlaubertForQuestionAnsweringSimple,
"text-classification": TFFlaubertForSequenceClassification,
"token-classification": TFFlaubertForTokenClassification,
"zero-shot": TFFlaubertForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -42,6 +42,7 @@ from ...test_modeling_common import ( ...@@ -42,6 +42,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -856,8 +857,9 @@ class FlavaModelTester: ...@@ -856,8 +857,9 @@ class FlavaModelTester:
@require_torch @require_torch
class FlavaModelTest(ModelTesterMixin, unittest.TestCase): class FlavaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (FlavaModel,) if is_torch_available() else () all_model_classes = (FlavaModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": FlavaModel} if is_torch_available() else {}
class_for_tester = FlavaModelTester class_for_tester = FlavaModelTester
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tokenizers, require_torch, slow, ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_tokenizers, require_torch, slow,
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -265,7 +266,7 @@ class FNetModelTester: ...@@ -265,7 +266,7 @@ class FNetModelTester:
@require_torch @require_torch
class FNetModelTest(ModelTesterMixin, unittest.TestCase): class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
FNetModel, FNetModel,
...@@ -280,6 +281,18 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -280,6 +281,18 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": FNetModel,
"fill-mask": FNetForMaskedLM,
"question-answering": FNetForQuestionAnswering,
"text-classification": FNetForSequenceClassification,
"token-classification": FNetForTokenClassification,
"zero-shot": FNetForSequenceClassification,
}
if is_torch_available()
else {}
)
# Skip Tests # Skip Tests
test_pruning = False test_pruning = False
......
...@@ -26,6 +26,7 @@ from transformers.utils import cached_property ...@@ -26,6 +26,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -153,9 +154,19 @@ def prepare_fsmt_inputs_dict( ...@@ -153,9 +154,19 @@ def prepare_fsmt_inputs_dict(
@require_torch @require_torch
class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else () all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": FSMTForConditionalGeneration,
"feature-extraction": FSMTModel,
"summarization": FSMTForConditionalGeneration,
"text2text-generation": FSMTForConditionalGeneration,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -351,7 +352,7 @@ class FunnelModelTester: ...@@ -351,7 +352,7 @@ class FunnelModelTester:
@require_torch @require_torch
class FunnelModelTest(ModelTesterMixin, unittest.TestCase): class FunnelModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
all_model_classes = ( all_model_classes = (
...@@ -365,6 +366,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -365,6 +366,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": (FunnelBaseModel, FunnelModel),
"fill-mask": FunnelForMaskedLM,
"question-answering": FunnelForQuestionAnswering,
"text-classification": FunnelForSequenceClassification,
"token-classification": FunnelForTokenClassification,
"zero-shot": FunnelForSequenceClassification,
}
if is_torch_available()
else {}
)
# special case for ForPreTraining model # special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, tooslow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, tooslow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -329,7 +330,7 @@ class TFFunnelModelTester: ...@@ -329,7 +330,7 @@ class TFFunnelModelTester:
@require_tf @require_tf
class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): class TFFunnelModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFFunnelModel, TFFunnelModel,
...@@ -341,6 +342,18 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -341,6 +342,18 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": (TFFunnelBaseModel, TFFunnelModel),
"fill-mask": TFFunnelForMaskedLM,
"question-answering": TFFunnelForQuestionAnswering,
"text-classification": TFFunnelForSequenceClassification,
"token-classification": TFFunnelForTokenClassification,
"zero-shot": TFFunnelForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -25,6 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc ...@@ -25,6 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -378,9 +379,12 @@ class GitModelTester: ...@@ -378,9 +379,12 @@ class GitModelTester:
@require_torch @require_torch
class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else () all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GitModel, "text-generation": GitForCausalLM} if is_torch_available() else {}
)
fx_compatible = False fx_compatible = False
test_torchscript = False test_torchscript = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -143,8 +144,11 @@ class GLPNModelTester: ...@@ -143,8 +144,11 @@ class GLPNModelTester:
@require_torch @require_torch
class GLPNModelTest(ModelTesterMixin, unittest.TestCase): class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else () all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
pipeline_model_mapping = (
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {}
)
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
......
...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -429,13 +430,24 @@ class GPT2ModelTester: ...@@ -429,13 +430,24 @@ class GPT2ModelTester:
@require_torch @require_torch
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, GPT2ForTokenClassification) (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, GPT2ForTokenClassification)
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPT2Model,
"text-classification": GPT2ForSequenceClassification,
"text-generation": GPT2LMHeadModel,
"token-classification": GPT2ForTokenClassification,
"zero-shot": GPT2ForSequenceClassification,
}
if is_torch_available()
else {}
)
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
fx_compatible = True fx_compatible = True
test_missing_keys = False test_missing_keys = False
......
...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_tf, require_tf2onnx, slow ...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_tf, require_tf2onnx, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
...@@ -361,13 +362,23 @@ class TFGPT2ModelTester: ...@@ -361,13 +362,23 @@ class TFGPT2ModelTester:
@require_tf @require_tf
class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase): class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFGPT2Model, TFGPT2LMHeadModel, TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel) (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel)
if is_tf_available() if is_tf_available()
else () else ()
) )
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else () all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": TFGPT2Model,
"text-classification": TFGPT2ForSequenceClassification,
"text-generation": TFGPT2LMHeadModel,
"zero-shot": TFGPT2ForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = True test_onnx = True
onnx_min_opset = 10 onnx_min_opset = 10
......
...@@ -24,6 +24,7 @@ from transformers.utils import cached_property ...@@ -24,6 +24,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -371,11 +372,21 @@ class GPTNeoModelTester: ...@@ -371,11 +372,21 @@ class GPTNeoModelTester:
@require_torch @require_torch
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else () (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
) )
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPTNeoModel,
"text-classification": GPTNeoForSequenceClassification,
"text-generation": GPTNeoForCausalLM,
"zero-shot": GPTNeoForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
test_missing_keys = False test_missing_keys = False
test_pruning = False test_pruning = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -185,9 +186,12 @@ class GPTNeoXModelTester: ...@@ -185,9 +186,12 @@ class GPTNeoXModelTester:
@require_torch @require_torch
class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase): class GPTNeoXModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else () all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
)
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
test_model_parallel = False test_model_parallel = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment