Unverified Commit e577bd0f authored by Matt's avatar Matt Committed by GitHub
Browse files

Use native TF checkpoints for the BLIP TF tests (#22593)

* Use native TF checkpoints for the TF tests

* Remove unneeded exceptions
parent 176ceff9
...@@ -189,10 +189,7 @@ class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -189,10 +189,7 @@ class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipVisionModel.from_pretrained(model_name) model = TFBlipVisionModel.from_pretrained(model_name)
except OSError:
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -320,10 +317,7 @@ class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -320,10 +317,7 @@ class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name) model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
...@@ -432,7 +426,7 @@ class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase ...@@ -432,7 +426,7 @@ class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True) model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
...@@ -635,7 +629,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -635,7 +629,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True) model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@unittest.skip(reason="Tested in individual model tests") @unittest.skip(reason="Tested in individual model tests")
...@@ -750,10 +744,7 @@ class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -750,10 +744,7 @@ class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipModel.from_pretrained(model_name) model = TFBlipModel.from_pretrained(model_name)
except OSError:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -769,7 +760,7 @@ def prepare_img(): ...@@ -769,7 +760,7 @@ def prepare_img():
@slow @slow
class TFBlipModelIntegrationTest(unittest.TestCase): class TFBlipModelIntegrationTest(unittest.TestCase):
def test_inference_image_captioning(self): def test_inference_image_captioning(self):
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True) model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = prepare_img() image = prepare_img()
...@@ -796,7 +787,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase): ...@@ -796,7 +787,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
) )
def test_inference_vqa(self): def test_inference_vqa(self):
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True) model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image = prepare_img() image = prepare_img()
...@@ -808,7 +799,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase): ...@@ -808,7 +799,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102]) self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])
def test_inference_itm(self): def test_inference_itm(self):
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True) model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco") processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
image = prepare_img() image = prepare_img()
......
...@@ -160,10 +160,7 @@ class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -160,10 +160,7 @@ class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name) model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True): def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment