Unverified Commit fa01127a authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

update_pip_test_mapping (#22606)



* Add TFBlipForConditionalGeneration

* update pipeline_model_mapping

* Add import

* Revert changes in GPTSanJapaneseTest

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 321b0908
...@@ -239,9 +239,10 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -239,9 +239,10 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
"fill-mask": MBartForConditionalGeneration, "fill-mask": MBartForConditionalGeneration,
"question-answering": MBartForQuestionAnswering, "question-answering": MBartForQuestionAnswering,
"summarization": MBartForConditionalGeneration, "summarization": MBartForConditionalGeneration,
"text2text-generation": MBartForConditionalGeneration,
"text-classification": MBartForSequenceClassification, "text-classification": MBartForSequenceClassification,
"text-generation": MBartForCausalLM, "text-generation": MBartForCausalLM,
"text2text-generation": MBartForConditionalGeneration,
"translation": MBartForConditionalGeneration,
"zero-shot": MBartForSequenceClassification, "zero-shot": MBartForSequenceClassification,
} }
if is_torch_available() if is_torch_available()
......
...@@ -190,6 +190,7 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas ...@@ -190,6 +190,7 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
"feature-extraction": TFMBartModel, "feature-extraction": TFMBartModel,
"summarization": TFMBartForConditionalGeneration, "summarization": TFMBartForConditionalGeneration,
"text2text-generation": TFMBartForConditionalGeneration, "text2text-generation": TFMBartForConditionalGeneration,
"translation": TFMBartForConditionalGeneration,
} }
if is_tf_available() if is_tf_available()
else {} else {}
......
...@@ -471,9 +471,11 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -471,9 +471,11 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MegaModel, "feature-extraction": MegaModel,
"fill-mask": MegaForMaskedLM,
"question-answering": MegaForQuestionAnswering, "question-answering": MegaForQuestionAnswering,
"text-classification": MegaForSequenceClassification, "text-classification": MegaForSequenceClassification,
"text-generation": MegaForCausalLM, "text-generation": MegaForCausalLM,
"token-classification": MegaForTokenClassification,
"zero-shot": MegaForSequenceClassification, "zero-shot": MegaForSequenceClassification,
} }
if is_torch_available() if is_torch_available()
......
...@@ -25,6 +25,7 @@ from transformers.utils import is_torch_available, is_vision_available ...@@ -25,6 +25,7 @@ from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -116,8 +117,9 @@ class MgpstrModelTester: ...@@ -116,8 +117,9 @@ class MgpstrModelTester:
@require_torch @require_torch
class MgpstrModelTest(ModelTesterMixin, unittest.TestCase): class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else () all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -420,9 +420,10 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -420,9 +420,10 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"fill-mask": MvpForConditionalGeneration, "fill-mask": MvpForConditionalGeneration,
"question-answering": MvpForQuestionAnswering, "question-answering": MvpForQuestionAnswering,
"summarization": MvpForConditionalGeneration, "summarization": MvpForConditionalGeneration,
"text2text-generation": MvpForConditionalGeneration,
"text-classification": MvpForSequenceClassification, "text-classification": MvpForSequenceClassification,
"text-generation": MvpForCausalLM, "text-generation": MvpForCausalLM,
"text2text-generation": MvpForConditionalGeneration,
"translation": MvpForConditionalGeneration,
"zero-shot": MvpForSequenceClassification, "zero-shot": MvpForSequenceClassification,
} }
if is_torch_available() if is_torch_available()
......
...@@ -255,6 +255,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -255,6 +255,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
"feature-extraction": NllbMoeModel, "feature-extraction": NllbMoeModel,
"summarization": NllbMoeForConditionalGeneration, "summarization": NllbMoeForConditionalGeneration,
"text2text-generation": NllbMoeForConditionalGeneration, "text2text-generation": NllbMoeForConditionalGeneration,
"translation": NllbMoeForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -242,8 +242,9 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -242,8 +242,9 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
"conversational": PegasusForConditionalGeneration, "conversational": PegasusForConditionalGeneration,
"feature-extraction": PegasusModel, "feature-extraction": PegasusModel,
"summarization": PegasusForConditionalGeneration, "summarization": PegasusForConditionalGeneration,
"text2text-generation": PegasusForConditionalGeneration,
"text-generation": PegasusForCausalLM, "text-generation": PegasusForCausalLM,
"text2text-generation": PegasusForConditionalGeneration,
"translation": PegasusForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -185,6 +185,7 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC ...@@ -185,6 +185,7 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
"feature-extraction": TFPegasusModel, "feature-extraction": TFPegasusModel,
"summarization": TFPegasusForConditionalGeneration, "summarization": TFPegasusForConditionalGeneration,
"text2text-generation": TFPegasusForConditionalGeneration, "text2text-generation": TFPegasusForConditionalGeneration,
"translation": TFPegasusForConditionalGeneration,
} }
if is_tf_available() if is_tf_available()
else {} else {}
......
...@@ -204,6 +204,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ...@@ -204,6 +204,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
"feature-extraction": PegasusXModel, "feature-extraction": PegasusXModel,
"summarization": PegasusXForConditionalGeneration, "summarization": PegasusXForConditionalGeneration,
"text2text-generation": PegasusXForConditionalGeneration, "text2text-generation": PegasusXForConditionalGeneration,
"translation": PegasusXForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -224,9 +224,10 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ...@@ -224,9 +224,10 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
"conversational": PLBartForConditionalGeneration, "conversational": PLBartForConditionalGeneration,
"feature-extraction": PLBartModel, "feature-extraction": PLBartModel,
"summarization": PLBartForConditionalGeneration, "summarization": PLBartForConditionalGeneration,
"text2text-generation": PLBartForConditionalGeneration,
"text-classification": PLBartForSequenceClassification, "text-classification": PLBartForSequenceClassification,
"text-generation": PLBartForCausalLM, "text-generation": PLBartForCausalLM,
"text2text-generation": PLBartForConditionalGeneration,
"translation": PLBartForConditionalGeneration,
"zero-shot": PLBartForSequenceClassification, "zero-shot": PLBartForSequenceClassification,
} }
if is_torch_available() if is_torch_available()
......
...@@ -894,8 +894,9 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ...@@ -894,8 +894,9 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
"conversational": ProphetNetForConditionalGeneration, "conversational": ProphetNetForConditionalGeneration,
"feature-extraction": ProphetNetModel, "feature-extraction": ProphetNetModel,
"summarization": ProphetNetForConditionalGeneration, "summarization": ProphetNetForConditionalGeneration,
"text2text-generation": ProphetNetForConditionalGeneration,
"text-generation": ProphetNetForCausalLM, "text-generation": ProphetNetForCausalLM,
"text2text-generation": ProphetNetForConditionalGeneration,
"translation": ProphetNetForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel ...@@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
"feature-extraction": SwitchTransformersModel, "feature-extraction": SwitchTransformersModel,
"summarization": SwitchTransformersForConditionalGeneration, "summarization": SwitchTransformersForConditionalGeneration,
"text2text-generation": SwitchTransformersForConditionalGeneration, "text2text-generation": SwitchTransformersForConditionalGeneration,
"translation": SwitchTransformersForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -528,6 +528,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -528,6 +528,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"feature-extraction": T5Model, "feature-extraction": T5Model,
"summarization": T5ForConditionalGeneration, "summarization": T5ForConditionalGeneration,
"text2text-generation": T5ForConditionalGeneration, "text2text-generation": T5ForConditionalGeneration,
"translation": T5ForConditionalGeneration,
} }
if is_torch_available() if is_torch_available()
else {} else {}
......
...@@ -250,6 +250,7 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -250,6 +250,7 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"feature-extraction": TFT5Model, "feature-extraction": TFT5Model,
"summarization": TFT5ForConditionalGeneration, "summarization": TFT5ForConditionalGeneration,
"text2text-generation": TFT5ForConditionalGeneration, "text2text-generation": TFT5ForConditionalGeneration,
"translation": TFT5ForConditionalGeneration,
} }
if is_tf_available() if is_tf_available()
else {} else {}
......
...@@ -277,7 +277,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -277,7 +277,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else () all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"automatic-speech-recognition": WhisperForConditionalGeneration, "feature-extraction": WhisperModel} {
"audio-classification": WhisperForAudioClassification,
"automatic-speech-recognition": WhisperForConditionalGeneration,
"feature-extraction": WhisperModel,
}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment