Unverified Commit fa01127a authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

update_pip_test_mapping (#22606)



* Add TFBlipForConditionalGeneration

* update pipeline_model_mapping

* Add import

* Revert changes in GPTSanJapaneseTest

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 321b0908
......@@ -231,6 +231,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("blip", "TFBlipForConditionalGeneration"),
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)
......
......@@ -40,6 +40,7 @@ from ...test_modeling_common import (
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -420,8 +421,9 @@ class AlignModelTester:
@require_torch
class AlignModelTest(ModelTesterMixin, unittest.TestCase):
class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AlignModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": AlignModel} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
test_pruning = False
......
......@@ -429,9 +429,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
"fill-mask": BartForConditionalGeneration,
"question-answering": BartForQuestionAnswering,
"summarization": BartForConditionalGeneration,
"text2text-generation": BartForConditionalGeneration,
"text-classification": BartForSequenceClassification,
"text-generation": BartForCausalLM,
"text2text-generation": BartForConditionalGeneration,
"translation": BartForConditionalGeneration,
"zero-shot": BartForSequenceClassification,
}
if is_torch_available()
......
......@@ -199,8 +199,9 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
"conversational": TFBartForConditionalGeneration,
"feature-extraction": TFBartModel,
"summarization": TFBartForConditionalGeneration,
"text2text-generation": TFBartForConditionalGeneration,
"text-classification": TFBartForSequenceClassification,
"text2text-generation": TFBartForConditionalGeneration,
"translation": TFBartForConditionalGeneration,
"zero-shot": TFBartForSequenceClassification,
}
if is_tf_available()
......
......@@ -251,9 +251,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
"feature-extraction": BigBirdPegasusModel,
"question-answering": BigBirdPegasusForQuestionAnswering,
"summarization": BigBirdPegasusForConditionalGeneration,
"text2text-generation": BigBirdPegasusForConditionalGeneration,
"text-classification": BigBirdPegasusForSequenceClassification,
"text-generation": BigBirdPegasusForCausalLM,
"text2text-generation": BigBirdPegasusForConditionalGeneration,
"translation": BigBirdPegasusForConditionalGeneration,
"zero-shot": BigBirdPegasusForSequenceClassification,
}
if is_torch_available()
......
......@@ -232,8 +232,9 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
"conversational": BlenderbotForConditionalGeneration,
"feature-extraction": BlenderbotModel,
"summarization": BlenderbotForConditionalGeneration,
"text2text-generation": BlenderbotForConditionalGeneration,
"text-generation": BlenderbotForCausalLM,
"text2text-generation": BlenderbotForConditionalGeneration,
"translation": BlenderbotForConditionalGeneration,
}
if is_torch_available()
else {}
......
......@@ -185,6 +185,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
"feature-extraction": TFBlenderbotModel,
"summarization": TFBlenderbotForConditionalGeneration,
"text2text-generation": TFBlenderbotForConditionalGeneration,
"translation": TFBlenderbotForConditionalGeneration,
}
if is_tf_available()
else {}
......
......@@ -226,8 +226,9 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
"conversational": BlenderbotSmallForConditionalGeneration,
"feature-extraction": BlenderbotSmallModel,
"summarization": BlenderbotSmallForConditionalGeneration,
"text2text-generation": BlenderbotSmallForConditionalGeneration,
"text-generation": BlenderbotSmallForCausalLM,
"text2text-generation": BlenderbotSmallForConditionalGeneration,
"translation": BlenderbotSmallForConditionalGeneration,
}
if is_torch_available()
else {}
......
......@@ -187,6 +187,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
"feature-extraction": TFBlenderbotSmallModel,
"summarization": TFBlenderbotSmallForConditionalGeneration,
"text2text-generation": TFBlenderbotSmallForConditionalGeneration,
"translation": TFBlenderbotSmallForConditionalGeneration,
}
if is_tf_available()
else {}
......
......@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -164,7 +165,7 @@ class ConvNextV2ModelTester:
@require_torch
class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase):
class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNextV2 does not use input_ids, inputs_embeds,
attention_mask and seq_length.
......@@ -179,6 +180,11 @@ class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = False
test_pruning = False
......
......@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
......@@ -122,13 +123,18 @@ class EfficientNetModelTester:
@require_torch
class EfficientNetModelTest(ModelTesterMixin, unittest.TestCase):
class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as EfficientNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = False
test_pruning = False
......
......@@ -163,6 +163,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
"feature-extraction": FSMTModel,
"summarization": FSMTForConditionalGeneration,
"text2text-generation": FSMTForConditionalGeneration,
"translation": FSMTForConditionalGeneration,
}
if is_torch_available()
else {}
......
......@@ -96,7 +96,7 @@ class GPTSanJapaneseTester:
def get_config(self):
return GPTSanJapaneseConfig(
vocab_size=36000,
vocab_size=self.vocab_size,
num_contexts=self.seq_length,
d_model=self.hidden_size,
d_ff=self.d_ff,
......
......@@ -26,6 +26,7 @@ from transformers.testing_utils import is_flaky, require_torch, slow, torch_devi
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
TOLERANCE = 1e-4
......@@ -177,9 +178,10 @@ class InformerModelTester:
@require_torch
class InformerModelTest(ModelTesterMixin, unittest.TestCase):
class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": InformerModel} if is_torch_available() else {}
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
......
......@@ -282,8 +282,9 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"feature-extraction": LEDModel,
"question-answering": LEDForQuestionAnswering,
"summarization": LEDForConditionalGeneration,
"text2text-generation": LEDForConditionalGeneration,
"text-classification": LEDForSequenceClassification,
"text2text-generation": LEDForConditionalGeneration,
"translation": LEDForConditionalGeneration,
"zero-shot": LEDForSequenceClassification,
}
if is_torch_available()
......
......@@ -199,6 +199,7 @@ class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
"feature-extraction": TFLEDModel,
"summarization": TFLEDForConditionalGeneration,
"text2text-generation": TFLEDForConditionalGeneration,
"translation": TFLEDForConditionalGeneration,
}
if is_tf_available()
else {}
......
......@@ -510,6 +510,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
"feature-extraction": LongT5Model,
"summarization": LongT5ForConditionalGeneration,
"text2text-generation": LongT5ForConditionalGeneration,
"translation": LongT5ForConditionalGeneration,
}
if is_torch_available()
else {}
......
......@@ -237,6 +237,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
"feature-extraction": M2M100Model,
"summarization": M2M100ForConditionalGeneration,
"text2text-generation": M2M100ForConditionalGeneration,
"translation": M2M100ForConditionalGeneration,
}
if is_torch_available()
else {}
......
......@@ -244,8 +244,9 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
"conversational": MarianMTModel,
"feature-extraction": MarianModel,
"summarization": MarianMTModel,
"text2text-generation": MarianMTModel,
"text-generation": MarianForCausalLM,
"text2text-generation": MarianMTModel,
"translation": MarianMTModel,
}
if is_torch_available()
else {}
......
......@@ -187,6 +187,7 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
"feature-extraction": TFMarianModel,
"summarization": TFMarianMTModel,
"text2text-generation": TFMarianMTModel,
"translation": TFMarianMTModel,
}
if is_tf_available()
else {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment