Unverified Commit 871c31a6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to diable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipieline_model_mapping

* Fix import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix one more import after adding pipieline_model_mapping

* Fix style and quality after adding pipieline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipieline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -189,9 +190,14 @@ class GPTNeoXJapaneseModelTester: ...@@ -189,9 +190,14 @@ class GPTNeoXJapaneseModelTester:
@require_torch @require_torch
class GPTNeoXModelJapaneseTest(ModelTesterMixin, unittest.TestCase): class GPTNeoXModelJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else () all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GPTNeoXJapaneseModel, "text-generation": GPTNeoXJapaneseForCausalLM}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
test_model_parallel = False test_model_parallel = False
......
...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, tooslow, torch_devic ...@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, tooslow, torch_devic
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -360,13 +361,24 @@ class GPTJModelTester: ...@@ -360,13 +361,24 @@ class GPTJModelTester:
@require_torch @require_torch
class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering) (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering)
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPTJModel,
"question-answering": GPTJForQuestionAnswering,
"text-classification": GPTJForSequenceClassification,
"text-generation": GPTJForCausalLM,
"zero-shot": GPTJForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
......
...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_tf, slow, tooslow ...@@ -20,6 +20,7 @@ from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
...@@ -293,7 +294,7 @@ class TFGPTJModelTester: ...@@ -293,7 +294,7 @@ class TFGPTJModelTester:
@require_tf @require_tf
class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase): class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFGPTJForCausalLM, TFGPTJForSequenceClassification, TFGPTJForQuestionAnswering, TFGPTJModel) (TFGPTJForCausalLM, TFGPTJForSequenceClassification, TFGPTJForQuestionAnswering, TFGPTJModel)
if is_tf_available() if is_tf_available()
...@@ -301,6 +302,17 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC ...@@ -301,6 +302,17 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
) )
all_generative_model_classes = (TFGPTJForCausalLM,) if is_tf_available() else () all_generative_model_classes = (TFGPTJForCausalLM,) if is_tf_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": TFGPTJModel,
"question-answering": TFGPTJForQuestionAnswering,
"text-classification": TFGPTJForSequenceClassification,
"text-generation": TFGPTJForCausalLM,
"zero-shot": TFGPTJForSequenceClassification,
}
if is_tf_available()
else {}
)
test_onnx = False test_onnx = False
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
......
...@@ -26,6 +26,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -26,6 +26,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -243,9 +244,10 @@ class GraphormerModelTester: ...@@ -243,9 +244,10 @@ class GraphormerModelTester:
@require_torch @require_torch
class GraphormerModelTest(ModelTesterMixin, unittest.TestCase): class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GraphormerForGraphClassification, GraphormerModel) if is_torch_available() else () all_model_classes = (GraphormerForGraphClassification, GraphormerModel) if is_torch_available() else ()
all_generative_model_classes = () all_generative_model_classes = ()
pipeline_model_mapping = {"feature-extraction": GraphormerModel} if is_torch_available() else {}
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -36,6 +36,7 @@ from ...test_modeling_common import ( ...@@ -36,6 +36,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -521,8 +522,9 @@ class GroupViTModelTester: ...@@ -521,8 +522,9 @@ class GroupViTModelTester:
@require_torch @require_torch
class GroupViTModelTest(ModelTesterMixin, unittest.TestCase): class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GroupViTModel,) if is_torch_available() else () all_model_classes = (GroupViTModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": GroupViTModel} if is_torch_available() else {}
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -37,6 +37,7 @@ from transformers.utils import is_tf_available, is_vision_available ...@@ -37,6 +37,7 @@ from transformers.utils import is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -569,8 +570,9 @@ class TFGroupViTModelTester: ...@@ -569,8 +570,9 @@ class TFGroupViTModelTester:
@require_tf @require_tf
class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase): class TFGroupViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFGroupViTModel,) if is_tf_available() else () all_model_classes = (TFGroupViTModel,) if is_tf_available() else ()
pipeline_model_mapping = {"feature-extraction": TFGroupViTModel} if is_tf_available() else {}
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -35,6 +35,7 @@ from ...test_modeling_common import ( ...@@ -35,6 +35,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -304,8 +305,17 @@ class HubertModelTester: ...@@ -304,8 +305,17 @@ class HubertModelTester:
@require_torch @require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase): class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
pipeline_model_mapping = (
{
"audio-classification": HubertForSequenceClassification,
"automatic-speech-recognition": HubertForCTC,
"feature-extraction": HubertModel,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
......
...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_soundfile, require_tf, slow ...@@ -27,6 +27,7 @@ from transformers.testing_utils import require_soundfile, require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -218,8 +219,9 @@ class TFHubertModelTester: ...@@ -218,8 +219,9 @@ class TFHubertModelTester:
@require_tf @require_tf
class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): class TFHubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFHubertModel, TFHubertForCTC) if is_tf_available() else () all_model_classes = (TFHubertModel, TFHubertForCTC) if is_tf_available() else ()
pipeline_model_mapping = {"feature-extraction": TFHubertModel} if is_tf_available() else {}
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -224,7 +225,7 @@ class IBertModelTester: ...@@ -224,7 +225,7 @@ class IBertModelTester:
@require_torch @require_torch
class IBertModelTest(ModelTesterMixin, unittest.TestCase): class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
test_head_masking = False test_head_masking = False
...@@ -242,6 +243,18 @@ class IBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -242,6 +243,18 @@ class IBertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": IBertModel,
"fill-mask": IBertForMaskedLM,
"question-answering": IBertForQuestionAnswering,
"text-classification": IBertForSequenceClassification,
"token-classification": IBertForTokenClassification,
"zero-shot": IBertForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self): def setUp(self):
self.model_tester = IBertModelTester(self) self.model_tester = IBertModelTester(self)
......
...@@ -33,6 +33,7 @@ from ...test_modeling_common import ( ...@@ -33,6 +33,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -264,11 +265,16 @@ class ImageGPTModelTester: ...@@ -264,11 +265,16 @@ class ImageGPTModelTester:
@require_torch @require_torch
class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else () (ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else ()
) )
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else () all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
if is_torch_available()
else {}
)
test_missing_keys = False test_missing_keys = False
input_name = "pixel_values" input_name = "pixel_values"
......
...@@ -19,6 +19,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -19,6 +19,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -219,7 +220,7 @@ class LayoutLMModelTester: ...@@ -219,7 +220,7 @@ class LayoutLMModelTester:
@require_torch @require_torch
class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
LayoutLMModel, LayoutLMModel,
...@@ -231,6 +232,18 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -231,6 +232,18 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else None else None
) )
pipeline_model_mapping = (
{
"document-question-answering": LayoutLMForQuestionAnswering,
"feature-extraction": LayoutLMModel,
"fill-mask": LayoutLMForMaskedLM,
"text-classification": LayoutLMForSequenceClassification,
"token-classification": LayoutLMForTokenClassification,
"zero-shot": LayoutLMForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
def setUp(self): def setUp(self):
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_tf, slow ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -206,7 +207,7 @@ class TFLayoutLMModelTester: ...@@ -206,7 +207,7 @@ class TFLayoutLMModelTester:
@require_tf @require_tf
class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase): class TFLayoutLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFLayoutLMModel, TFLayoutLMModel,
...@@ -218,6 +219,17 @@ class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -218,6 +219,17 @@ class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFLayoutLMModel,
"fill-mask": TFLayoutLMForMaskedLM,
"text-classification": TFLayoutLMForSequenceClassification,
"token-classification": TFLayoutLMForTokenClassification,
"zero-shot": TFLayoutLMForSequenceClassification,
}
if is_tf_available()
else {}
)
test_head_masking = False test_head_masking = False
test_onnx = True test_onnx = True
onnx_min_opset = 10 onnx_min_opset = 10
......
...@@ -22,6 +22,7 @@ from transformers.utils import is_detectron2_available, is_torch_available ...@@ -22,6 +22,7 @@ from transformers.utils import is_detectron2_available, is_torch_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -253,7 +254,7 @@ class LayoutLMv2ModelTester: ...@@ -253,7 +254,7 @@ class LayoutLMv2ModelTester:
@require_torch @require_torch
@require_detectron2 @require_detectron2
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase): class LayoutLMv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True
test_mismatched_shapes = False test_mismatched_shapes = False
...@@ -268,6 +269,18 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -268,6 +269,18 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"document-question-answering": LayoutLMv2ForQuestionAnswering,
"feature-extraction": LayoutLMv2Model,
"question-answering": LayoutLMv2ForQuestionAnswering,
"text-classification": LayoutLMv2ForSequenceClassification,
"token-classification": LayoutLMv2ForTokenClassification,
"zero-shot": LayoutLMv2ForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self): def setUp(self):
self.model_tester = LayoutLMv2ModelTester(self) self.model_tester = LayoutLMv2ModelTester(self)
......
...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av ...@@ -23,6 +23,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -269,7 +270,7 @@ class LayoutLMv3ModelTester: ...@@ -269,7 +270,7 @@ class LayoutLMv3ModelTester:
@require_torch @require_torch
class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase): class LayoutLMv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
test_mismatched_shapes = False test_mismatched_shapes = False
...@@ -284,6 +285,18 @@ class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -284,6 +285,18 @@ class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"document-question-answering": LayoutLMv3ForQuestionAnswering,
"feature-extraction": LayoutLMv3Model,
"question-answering": LayoutLMv3ForQuestionAnswering,
"text-classification": LayoutLMv3ForSequenceClassification,
"token-classification": LayoutLMv3ForTokenClassification,
"zero-shot": LayoutLMv3ForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self): def setUp(self):
self.model_tester = LayoutLMv3ModelTester(self) self.model_tester = LayoutLMv3ModelTester(self)
......
...@@ -27,6 +27,7 @@ from transformers.utils import cached_property ...@@ -27,6 +27,7 @@ from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -263,7 +264,7 @@ class TFLayoutLMv3ModelTester: ...@@ -263,7 +264,7 @@ class TFLayoutLMv3ModelTester:
@require_tf @require_tf
class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase): class TFLayoutLMv3ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
TFLayoutLMv3Model, TFLayoutLMv3Model,
...@@ -274,6 +275,17 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -274,6 +275,17 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": TFLayoutLMv3Model,
"question-answering": TFLayoutLMv3ForQuestionAnswering,
"text-classification": TFLayoutLMv3ForSequenceClassification,
"token-classification": TFLayoutLMv3ForTokenClassification,
"zero-shot": TFLayoutLMv3ForSequenceClassification,
}
if is_tf_available()
else {}
)
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
......
...@@ -27,6 +27,7 @@ from transformers.utils import cached_property ...@@ -27,6 +27,7 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -268,13 +269,26 @@ class LEDModelTester: ...@@ -268,13 +269,26 @@ class LEDModelTester:
@require_torch @require_torch
class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(LEDModel, LEDForConditionalGeneration, LEDForSequenceClassification, LEDForQuestionAnswering) (LEDModel, LEDForConditionalGeneration, LEDForSequenceClassification, LEDForQuestionAnswering)
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"conversational": LEDForConditionalGeneration,
"feature-extraction": LEDModel,
"question-answering": LEDForQuestionAnswering,
"summarization": LEDForConditionalGeneration,
"text2text-generation": LEDForConditionalGeneration,
"text-classification": LEDForSequenceClassification,
"zero-shot": LEDForSequenceClassification,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow, tooslow ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
...@@ -189,9 +190,19 @@ def prepare_led_inputs_dict( ...@@ -189,9 +190,19 @@ def prepare_led_inputs_dict(
@require_tf @require_tf
class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFLEDForConditionalGeneration, TFLEDModel) if is_tf_available() else () all_model_classes = (TFLEDForConditionalGeneration, TFLEDModel) if is_tf_available() else ()
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
pipeline_model_mapping = (
{
"conversational": TFLEDForConditionalGeneration,
"feature-extraction": TFLEDModel,
"summarization": TFLEDForConditionalGeneration,
"text2text-generation": TFLEDForConditionalGeneration,
}
if is_tf_available()
else {}
)
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
......
...@@ -29,6 +29,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc ...@@ -29,6 +29,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -165,7 +166,7 @@ class LevitModelTester: ...@@ -165,7 +166,7 @@ class LevitModelTester:
@require_torch @require_torch
class LevitModelTest(ModelTesterMixin, unittest.TestCase): class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
""" """
Here we also overwrite some of the tests of test_modeling_common.py, as Levit does not use input_ids, inputs_embeds, Here we also overwrite some of the tests of test_modeling_common.py, as Levit does not use input_ids, inputs_embeds,
attention_mask and seq_length. attention_mask and seq_length.
...@@ -176,6 +177,14 @@ class LevitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -176,6 +177,14 @@ class LevitModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": LevitModel,
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
}
if is_torch_available()
else {}
)
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
......
...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device ...@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -218,7 +219,7 @@ class LiltModelTester: ...@@ -218,7 +219,7 @@ class LiltModelTester:
@require_torch @require_torch
class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
LiltModel, LiltModel,
...@@ -229,6 +230,17 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -229,6 +230,17 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": LiltModel,
"question-answering": LiltForQuestionAnswering,
"text-classification": LiltForSequenceClassification,
"token-classification": LiltForTokenClassification,
"zero-shot": LiltForSequenceClassification,
}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers ...@@ -21,6 +21,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -296,7 +297,7 @@ class LongformerModelTester: ...@@ -296,7 +297,7 @@ class LongformerModelTester:
@require_torch @require_torch
class LongformerModelTest(ModelTesterMixin, unittest.TestCase): class LongformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported test_pruning = False # pruning is not supported
test_torchscript = False test_torchscript = False
...@@ -312,6 +313,18 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -312,6 +313,18 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = (
{
"feature-extraction": LongformerModel,
"fill-mask": LongformerForMaskedLM,
"question-answering": LongformerForQuestionAnswering,
"text-classification": LongformerForSequenceClassification,
"token-classification": LongformerForTokenClassification,
"zero-shot": LongformerForSequenceClassification,
}
if is_torch_available()
else {}
)
def setUp(self): def setUp(self):
self.model_tester = LongformerModelTester(self) self.model_tester = LongformerModelTester(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment