Unverified Commit 297d769d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Better test name and enable pipeline test for `pix2struct` (#24377)



* best test name forever

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 6950f70b
...@@ -35,6 +35,7 @@ from ...test_modeling_common import ( ...@@ -35,6 +35,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
...@@ -354,7 +355,7 @@ class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -354,7 +355,7 @@ class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
class Pix2StructTextImageModelsModelTester: class Pix2StructModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None: if text_kwargs is None:
text_kwargs = {} text_kwargs = {}
...@@ -394,8 +395,9 @@ class Pix2StructTextImageModelsModelTester: ...@@ -394,8 +395,9 @@ class Pix2StructTextImageModelsModelTester:
@require_torch @require_torch
class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False fx_compatible = False
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
...@@ -404,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -404,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = False test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = Pix2StructTextImageModelsModelTester(self) self.model_tester = Pix2StructModelTester(self)
def test_model(self): def test_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment