Unverified Commit edc1e734 authored by Yih-Dar, committed by GitHub

Fix Blip-2 CI (#21595)



* use fp16

* use fp16

* use fp16

* use fp16

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent fd5320bb
...
@@ -768,11 +768,13 @@ def prepare_img():
 class Blip2ModelIntegrationTest(unittest.TestCase):
     def test_inference_opt(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
+        model = Blip2ForConditionalGeneration.from_pretrained(
+            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
+        ).to(torch_device)
 
         # prepare image
         image = prepare_img()
-        inputs = processor(images=image, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
 
         predictions = model.generate(**inputs)
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
@@ -783,7 +785,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         # image and context
         prompt = "Question: which city is this? Answer:"
 
-        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
 
         predictions = model.generate(**inputs)
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
@@ -797,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
 
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
-        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
+        model = Blip2ForConditionalGeneration.from_pretrained(
+            "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
+        ).to(torch_device)
 
         # prepare image
         image = prepare_img()
-        inputs = processor(images=image, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
 
         predictions = model.generate(**inputs)
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
@@ -812,7 +816,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         # image and context
         prompt = "Question: which city is this? Answer:"
 
-        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
 
         predictions = model.generate(**inputs)
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
...
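For context, the pattern this commit applies can be shown as a standalone script: the checkpoint is loaded directly in torch.float16 (roughly halving GPU memory versus fp32, which is what lets the integration test fit on the CI runner), and the processor's floating-point tensors are cast to the same dtype before calling generate. This is a minimal sketch under assumptions not in the diff: the device selection, the sample image URL, and the use of PIL/requests for loading it are illustrative, not part of the test.

# Minimal sketch of the fp16 inference pattern applied in this commit.
# A CUDA device is assumed; fp16 inference on CPU may not be supported for all ops.
import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# Load the weights in half precision instead of casting after loading,
# so peak memory stays low.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(device)

# Hypothetical sample image; the test's prepare_img() helper supplies its own.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "Question: which city is this? Answer:"
# Cast the pixel values to fp16 so they match the model's weights;
# integer tensors such as input_ids are left unchanged by the dtype cast.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(
    device, dtype=torch.float16
)

predictions = model.generate(**inputs)
print(processor.batch_decode(predictions, skip_special_tokens=True)[0].strip())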