Unverified Commit fa01127a authored by Yih-Dar, committed by GitHub

update_pip_test_mapping (#22606)



* Add TFBlipForConditionalGeneration

* update pipeline_model_mapping

* Add import

* Revert changes in GPTSanJapaneseTest

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 321b0908
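
All of the hunks below touch the same construct: the pipeline_model_mapping class attribute that PipelineTesterMixin uses to pair pipeline task names with the model classes to exercise for a given architecture. As orientation, here is a condensed sketch of that pattern, assembled from the MBart and Mgpstr hunks below. It is illustrative only: the real test classes also mix in GenerationTesterMixin and map more tasks, and the relative imports resolve only inside tests/models/<model>/ in the transformers repo, exactly as in the files being patched.

import unittest

from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available

from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin

if is_torch_available():
    from transformers import MBartForConditionalGeneration, MBartModel


@require_torch
class MBartModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (MBartModel, MBartForConditionalGeneration) if is_torch_available() else ()
    # Task name -> model class instantiated when the pipeline tests run for this architecture.
    # Every mapping touched by this commit ends up with its keys in sorted order, which is
    # why "text2text-generation" moves below "text-generation" in several hunks.
    pipeline_model_mapping = (
        {
            "feature-extraction": MBartModel,
            "summarization": MBartForConditionalGeneration,
            "text2text-generation": MBartForConditionalGeneration,
            "translation": MBartForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )

The individual hunks then either add missing task entries (most often "translation"), move existing keys into sorted position, or, as in the Mgpstr file, add the PipelineTesterMixin base class and the import it needs.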
@@ -239,9 +239,10 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
"fill-mask": MBartForConditionalGeneration,
"question-answering": MBartForQuestionAnswering,
"summarization": MBartForConditionalGeneration,
"text2text-generation": MBartForConditionalGeneration,
"text-classification": MBartForSequenceClassification,
"text-generation": MBartForCausalLM,
"text2text-generation": MBartForConditionalGeneration,
"translation": MBartForConditionalGeneration,
"zero-shot": MBartForSequenceClassification,
}
if is_torch_available()
......
@@ -190,6 +190,7 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
"feature-extraction": TFMBartModel,
"summarization": TFMBartForConditionalGeneration,
"text2text-generation": TFMBartForConditionalGeneration,
"translation": TFMBartForConditionalGeneration,
}
if is_tf_available()
else {}
......
@@ -471,9 +471,11 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pipeline_model_mapping = (
{
"feature-extraction": MegaModel,
"fill-mask": MegaForMaskedLM,
"question-answering": MegaForQuestionAnswering,
"text-classification": MegaForSequenceClassification,
"text-generation": MegaForCausalLM,
"token-classification": MegaForTokenClassification,
"zero-shot": MegaForSequenceClassification,
}
if is_torch_available()
......
@@ -25,6 +25,7 @@ from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
@@ -116,8 +117,9 @@ class MgpstrModelTester:
@require_torch
class MgpstrModelTest(ModelTesterMixin, unittest.TestCase):
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
fx_compatible = False
test_pruning = False
......
@@ -420,9 +420,10 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"fill-mask": MvpForConditionalGeneration,
"question-answering": MvpForQuestionAnswering,
"summarization": MvpForConditionalGeneration,
"text2text-generation": MvpForConditionalGeneration,
"text-classification": MvpForSequenceClassification,
"text-generation": MvpForCausalLM,
"text2text-generation": MvpForConditionalGeneration,
"translation": MvpForConditionalGeneration,
"zero-shot": MvpForSequenceClassification,
}
if is_torch_available()
......
@@ -255,6 +255,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
"feature-extraction": NllbMoeModel,
"summarization": NllbMoeForConditionalGeneration,
"text2text-generation": NllbMoeForConditionalGeneration,
"translation": NllbMoeForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -242,8 +242,9 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
"conversational": PegasusForConditionalGeneration,
"feature-extraction": PegasusModel,
"summarization": PegasusForConditionalGeneration,
"text2text-generation": PegasusForConditionalGeneration,
"text-generation": PegasusForCausalLM,
"text2text-generation": PegasusForConditionalGeneration,
"translation": PegasusForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -185,6 +185,7 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
"feature-extraction": TFPegasusModel,
"summarization": TFPegasusForConditionalGeneration,
"text2text-generation": TFPegasusForConditionalGeneration,
"translation": TFPegasusForConditionalGeneration,
}
if is_tf_available()
else {}
......
@@ -204,6 +204,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
"feature-extraction": PegasusXModel,
"summarization": PegasusXForConditionalGeneration,
"text2text-generation": PegasusXForConditionalGeneration,
"translation": PegasusXForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -224,9 +224,10 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
"conversational": PLBartForConditionalGeneration,
"feature-extraction": PLBartModel,
"summarization": PLBartForConditionalGeneration,
"text2text-generation": PLBartForConditionalGeneration,
"text-classification": PLBartForSequenceClassification,
"text-generation": PLBartForCausalLM,
"text2text-generation": PLBartForConditionalGeneration,
"translation": PLBartForConditionalGeneration,
"zero-shot": PLBartForSequenceClassification,
}
if is_torch_available()
......
@@ -894,8 +894,9 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
"conversational": ProphetNetForConditionalGeneration,
"feature-extraction": ProphetNetModel,
"summarization": ProphetNetForConditionalGeneration,
"text2text-generation": ProphetNetForConditionalGeneration,
"text-generation": ProphetNetForCausalLM,
"text2text-generation": ProphetNetForConditionalGeneration,
"translation": ProphetNetForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
"feature-extraction": SwitchTransformersModel,
"summarization": SwitchTransformersForConditionalGeneration,
"text2text-generation": SwitchTransformersForConditionalGeneration,
"translation": SwitchTransformersForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -528,6 +528,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"feature-extraction": T5Model,
"summarization": T5ForConditionalGeneration,
"text2text-generation": T5ForConditionalGeneration,
"translation": T5ForConditionalGeneration,
}
if is_torch_available()
else {}
......
@@ -250,6 +250,7 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"feature-extraction": TFT5Model,
"summarization": TFT5ForConditionalGeneration,
"text2text-generation": TFT5ForConditionalGeneration,
"translation": TFT5ForConditionalGeneration,
}
if is_tf_available()
else {}
......
@@ -277,7 +277,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{"automatic-speech-recognition": WhisperForConditionalGeneration, "feature-extraction": WhisperModel}
{
"audio-classification": WhisperForAudioClassification,
"automatic-speech-recognition": WhisperForConditionalGeneration,
"feature-extraction": WhisperModel,
}
if is_torch_available()
else {}
)
......
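
For reference, here is the Whisper mapping from the last hunk, pulled out as a standalone snippet. This is only an illustration of what the attribute encodes (it is not how PipelineTesterMixin consumes it internally) and it assumes a transformers version recent enough to ship WhisperForAudioClassification.

from transformers import (
    WhisperForAudioClassification,
    WhisperForConditionalGeneration,
    WhisperModel,
)

# The mapping added in the hunk above: pipeline task name -> model class.
pipeline_model_mapping = {
    "audio-classification": WhisperForAudioClassification,
    "automatic-speech-recognition": WhisperForConditionalGeneration,
    "feature-extraction": WhisperModel,
}

# Keys stay in sorted order, matching every other mapping touched by this commit.
assert list(pipeline_model_mapping) == sorted(pipeline_model_mapping)

for task, model_class in pipeline_model_mapping.items():
    print(f"{task:32} -> {model_class.__name__}")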