Unverified Commit e3d832ff authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix Blip-2 CI again (#21637)



* fix blip-2 ci

* fix blip-2 ci

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 762dda44
...@@ -799,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase): ...@@ -799,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
def test_inference_opt_batched_beam_search(self): def test_inference_opt_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device) model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(torch_device)
# prepare image # prepare image
image = prepare_img() image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device) inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
predictions = model.generate(**inputs, num_beams=2) predictions = model.generate(**inputs, num_beams=2)
...@@ -844,14 +846,16 @@ class Blip2ModelIntegrationTest(unittest.TestCase): ...@@ -844,14 +846,16 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
def test_inference_t5_batched_beam_search(self): def test_inference_t5_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device) model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
).to(torch_device)
# prepare image # prepare image
image = prepare_img() image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device) inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
predictions = model.generate(**inputs, num_beams=2) predictions = model.generate(**inputs, num_beams=2)
# Test output (in this case, slightly different from greedy search) # Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1]) self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1]) self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment