Unverified Commit 871c31a6 authored by Yih-Dar, committed by GitHub

🔥Rework pipeline testing by removing `PipelineTestCaseMeta` 🚀 (#21516)



* Add PipelineTesterMixin

* remove class PipelineTestCaseMeta

* move validate_test_components

* Add for ViT

* Add to SPECIAL_MODULE_TO_TEST_MAP

* style and quality

* Add feature-extraction

* update

* raise instead of skip

* add tiny_model_summary.json

* more explicit

* skip tasks not in mapping

* add availability check

* Add Copyright

* A way to disable irrelevant tests

* update with main

* remove disable_irrelevant_tests

* skip tests

* better skip message

* better skip message

* Add all pipeline task tests

* revert

* Import PipelineTesterMixin

* subclass test classes with PipelineTesterMixin

* Add pipeline_model_mapping

* Fix import after adding pipeline_model_mapping

* Fix style and quality after adding pipeline_model_mapping

* Fix one more import after adding pipeline_model_mapping

* Fix style and quality after adding pipeline_model_mapping

* Fix test issues

* Fix import requirements

* Fix mapping for MobileViTModelTest

* Update

* Better skip message

* pipeline_model_mapping could not be None

* Remove some PipelineTesterMixin

* Fix typo

* revert tests_fetcher.py

* update

* rename

* revert

* Remove PipelineTestCaseMeta from ZeroShotAudioClassificationPipelineTests

* style and quality

* test fetcher for all pipeline/model tests

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 4cb5ffa9
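The diff below applies one recurring pattern to every affected model test file: import `PipelineTesterMixin`, add it to the test class bases, and declare a `pipeline_model_mapping` from pipeline task names to the model classes that should serve them. The following sketch illustrates that pattern only; `DummyModel` and `DummyModelForSequenceClassification` are hypothetical placeholders, and the file is assumed to live under `tests/models/<model>/` so the relative imports resolve.

```python
# Minimal sketch of the pattern this commit applies (placeholder model classes, not from the diff).
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin

if is_torch_available():
    # Hypothetical model classes standing in for the real ones imported by each test file.
    from transformers import DummyModel, DummyModelForSequenceClassification


@require_torch
class DummyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (DummyModel, DummyModelForSequenceClassification) if is_torch_available() else ()
    # Maps each supported pipeline task to the model class the shared pipeline tests should use;
    # per the commit notes, tasks absent from this mapping are skipped for the model.
    pipeline_model_mapping = (
        {
            "feature-extraction": DummyModel,
            "text-classification": DummyModelForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
```

With the mapping declared on the test class itself, the pipeline tests run alongside the regular model tests and the removed `PipelineTestCaseMeta` is no longer needed.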
@@ -25,6 +25,7 @@ from transformers.utils import cached_property
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -186,10 +187,11 @@ class BridgeTowerModelTester:
 @require_torch
 @unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
-class BridgeTowerModelTest(ModelTesterMixin, unittest.TestCase):
+class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
     )
+    pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {}
     is_training = False
     test_headmasking = False
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init, global_rng, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -207,7 +208,7 @@ class CanineModelTester:
 @require_torch
-class CanineModelTest(ModelTesterMixin, unittest.TestCase):
+class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             CanineModel,
@@ -219,6 +220,17 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": CanineModel,
+            "question-answering": CanineForQuestionAnswering,
+            "text-classification": CanineForSequenceClassification,
+            "token-classification": CanineForTokenClassification,
+            "zero-shot": CanineForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_mismatched_shapes = False
     test_resize_embeddings = False
...
@@ -35,6 +35,7 @@ from ...test_modeling_common import (
     ids_tensor,
     random_attention_mask,
 )
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -533,8 +534,9 @@ class ChineseCLIPModelTester:
 @require_torch
-class ChineseCLIPModelTest(ModelTesterMixin, unittest.TestCase):
+class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (ChineseCLIPModel,) if is_torch_available() else ()
+    pipeline_model_mapping = {"feature-extraction": ChineseCLIPModel} if is_torch_available() else {}
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
...
@@ -35,6 +35,7 @@ from ...test_modeling_common import (
     ids_tensor,
     random_attention_mask,
 )
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -477,8 +478,9 @@ class ClapModelTester:
 @require_torch
-class ClapModelTest(ModelTesterMixin, unittest.TestCase):
+class ClapModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (ClapModel,) if is_torch_available() else ()
+    pipeline_model_mapping = {"feature-extraction": ClapModel} if is_torch_available() else {}
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
...
@@ -43,6 +43,7 @@ from ...test_modeling_common import (
     ids_tensor,
     random_attention_mask,
 )
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -449,8 +450,9 @@ class CLIPModelTester:
 @require_torch
-class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
+class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (CLIPModel,) if is_torch_available() else ()
+    pipeline_model_mapping = {"feature-extraction": CLIPModel} if is_torch_available() else {}
     fx_compatible = True
     test_head_masking = False
     test_pruning = False
...
@@ -29,6 +29,7 @@ from transformers.utils import is_tf_available, is_vision_available
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_tf_available():
@@ -515,8 +516,9 @@ class TFCLIPModelTester:
 @require_tf
-class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFCLIPModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (TFCLIPModel,) if is_tf_available() else ()
+    pipeline_model_mapping = {"feature-extraction": TFCLIPModel} if is_tf_available() else {}
     test_head_masking = False
     test_pruning = False
     test_resize_embeddings = False
...
@@ -44,6 +44,7 @@ from ...test_modeling_common import (
     ids_tensor,
     random_attention_mask,
 )
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -412,8 +413,9 @@ class CLIPSegModelTester:
 @require_torch
-class CLIPSegModelTest(ModelTesterMixin, unittest.TestCase):
+class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (CLIPSegModel, CLIPSegForImageSegmentation) if is_torch_available() else ()
+    pipeline_model_mapping = {"feature-extraction": CLIPSegModel} if is_torch_available() else {}
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -348,9 +349,12 @@ class CodeGenModelTester:
 @require_torch
-class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (CodeGenModel, CodeGenForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (CodeGenForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": CodeGenModel, "text-generation": CodeGenForCausalLM} if is_torch_available() else {}
+    )
     fx_compatible = False
     test_pruning = False
     test_missing_keys = False
...
@@ -26,6 +26,7 @@ from transformers.utils import cached_property
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_timm_available():
@@ -179,7 +180,7 @@ class ConditionalDetrModelTester:
 @require_timm
-class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             ConditionalDetrModel,
@@ -189,6 +190,11 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         if is_timm_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {"feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
+        if is_timm_available()
+        else {}
+    )
     is_encoder_decoder = True
     test_torchscript = False
     test_pruning = False
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -245,7 +246,7 @@ class ConvBertModelTester:
 @require_torch
-class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
+class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             ConvBertModel,
@@ -258,6 +259,18 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": ConvBertModel,
+            "fill-mask": ConvBertForMaskedLM,
+            "question-answering": ConvBertForQuestionAnswering,
+            "text-classification": ConvBertForSequenceClassification,
+            "token-classification": ConvBertForTokenClassification,
+            "zero-shot": ConvBertForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_head_masking = False
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_tf_available():
@@ -223,7 +224,7 @@ class TFConvBertModelTester:
 @require_tf
-class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFConvBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFConvBertModel,
@@ -236,6 +237,18 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFConvBertModel,
+            "fill-mask": TFConvBertForMaskedLM,
+            "question-answering": TFConvBertForQuestionAnswering,
+            "text-classification": TFConvBertForSequenceClassification,
+            "token-classification": TFConvBertForTokenClassification,
+            "zero-shot": TFConvBertForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_pruning = False
     test_head_masking = False
     test_onnx = False
...
@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_available
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -152,7 +153,7 @@ class ConvNextModelTester:
 @require_torch
-class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
+class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
@@ -167,6 +168,11 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {"feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
     test_pruning = False
...
@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_available
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_tf_available():
@@ -117,13 +118,18 @@ class TFConvNextModelTester:
 @require_tf
-class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFConvNextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
     """

     all_model_classes = (TFConvNextModel, TFConvNextForImageClassification) if is_tf_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": TFConvNextModel, "image-classification": TFConvNextForImageClassification}
+        if is_tf_available()
+        else {}
+    )
     test_pruning = False
     test_onnx = False
...
@@ -22,6 +22,7 @@ from transformers.testing_utils import require_torch, slow, torch_device
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -192,9 +193,19 @@ class CTRLModelTester:
 @require_torch
-class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": CTRLModel,
+            "text-classification": CTRLForSequenceClassification,
+            "text-generation": CTRLLMHeadModel,
+            "zero-shot": CTRLForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = True
     test_resize_embeddings = False
     test_head_masking = False
...
@@ -21,6 +21,7 @@ from transformers.testing_utils import require_tf, slow
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_tf_available():
@@ -168,9 +169,19 @@ class TFCTRLModelTester(object):
 @require_tf
-class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
     all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": TFCTRLModel,
+            "text-classification": TFCTRLForSequenceClassification,
+            "text-generation": TFCTRLLMHeadModel,
+            "zero-shot": TFCTRLForSequenceClassification,
+        }
+        if is_tf_available()
+        else {}
+    )
     test_head_masking = False
     test_onnx = False
...
@@ -25,6 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -143,13 +144,18 @@ class CvtModelTester:
 @require_torch
-class CvtModelTest(ModelTesterMixin, unittest.TestCase):
+class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as Cvt does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
     """

     all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_torchscript = False
...
@@ -13,6 +13,7 @@ from transformers.utils import cached_property, is_tf_available, is_vision_available
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_tf_available():
@@ -128,13 +129,18 @@ class TFCvtModelTester:
 @require_tf
-class TFCvtModelTest(TFModelTesterMixin, unittest.TestCase):
+class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as Cvt
     does not use input_ids, inputs_embeds, attention_mask and seq_length.
     """

     all_model_classes = (TFCvtModel, TFCvtForImageClassification) if is_tf_available() else ()
+    pipeline_model_mapping = (
+        {"feature-extraction": TFCvtModel, "image-classification": TFCvtForImageClassification}
+        if is_tf_available()
+        else {}
+    )
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
...
@@ -26,6 +26,7 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile,
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -358,7 +359,7 @@ class Data2VecAudioModelTester:
 @require_torch
-class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
+class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             Data2VecAudioForCTC,
@@ -370,6 +371,15 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "audio-classification": Data2VecAudioForSequenceClassification,
+            "automatic-speech-recognition": Data2VecAudioForCTC,
+            "feature-extraction": Data2VecAudioModel,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_headmasking = False
...
@@ -23,6 +23,7 @@ from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -359,7 +360,7 @@ class Data2VecTextModelTester:
 @require_torch
-class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             Data2VecTextForCausalLM,
@@ -374,6 +375,19 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (Data2VecTextForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": Data2VecTextModel,
+            "fill-mask": Data2VecTextForMaskedLM,
+            "question-answering": Data2VecTextForQuestionAnswering,
+            "text-classification": Data2VecTextForSequenceClassification,
+            "text-generation": Data2VecTextForCausalLM,
+            "token-classification": Data2VecTextForTokenClassification,
+            "zero-shot": Data2VecTextForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )

     def setUp(self):
         self.model_tester = Data2VecTextModelTester(self)
...
@@ -25,6 +25,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_available
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
@@ -165,7 +166,7 @@ class Data2VecVisionModelTester:
 @require_torch
-class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
+class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     """
     Here we also overwrite some of the tests of test_modeling_common.py, as Data2VecVision does not use input_ids, inputs_embeds,
     attention_mask and seq_length.
@@ -176,6 +177,15 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": Data2VecVisionModel,
+            "image-classification": Data2VecVisionForImageClassification,
+            "image-segmentation": Data2VecVisionForSemanticSegmentation,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_resize_embeddings = False
...