"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "592f2eabd17cbdebd13dec54edf412f9f8232152"
Unverified Commit 09dc9951 authored by Juan Pizarro's avatar Juan Pizarro Committed by GitHub
Browse files

Add Blip2 model in VQA pipeline (#25532)

* Add Blip2 model in VQA pipeline

* use require_torch_gpu for test_large_model_pt_blip2

* use can_generate in vqa pipeline

* test Blip2ForConditionalGeneration using float16

* remove custom can_generate from Blip2ForConditionalGeneration
parent 62399d6f
...@@ -839,6 +839,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ...@@ -839,6 +839,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
# Maps model-type identifiers to the class name used for the
# visual-question-answering task. Entries are kept in alphabetical order;
# "blip-2" is the new entry added by this commit so the VQA pipeline can
# resolve Blip2ForConditionalGeneration.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
......
...@@ -124,19 +124,28 @@ class VisualQuestionAnsweringPipeline(Pipeline): ...@@ -124,19 +124,28 @@ class VisualQuestionAnsweringPipeline(Pipeline):
return model_inputs return model_inputs
def _forward(self, model_inputs): def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs) if self.model.can_generate():
model_outputs = self.model.generate(**model_inputs)
else:
model_outputs = self.model(**model_inputs)
return model_outputs return model_outputs
def postprocess(self, model_outputs, top_k=5):
    """Convert raw model outputs into a list of answer dicts.

    Args:
        model_outputs: Either generated token-id sequences (generative
            models) or a model output object with ``logits``
            (classification models).
        top_k (int, defaults to 5): Number of top answers to return for
            classification models; ignored for generative models.

    Returns:
        list[dict]: For generative models, ``[{"answer": str}]`` entries
        decoded from the generated ids. For classification models,
        ``[{"score": float, "answer": str}]`` entries for the top-k labels.

    Raises:
        ValueError: If the framework is not PyTorch for a classification model.
    """
    if self.model.can_generate():
        # Decode each generated sequence to text; strip() removes the
        # leading/trailing whitespace generative LMs commonly emit.
        return [
            {"answer": self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()}
            for output_ids in model_outputs
        ]
    else:
        # Clamp top_k so topk() never requests more labels than exist.
        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels

        if self.framework == "pt":
            # Multi-label VQA heads use an independent sigmoid per label
            # (not softmax); [0] drops the batch dimension.
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
...@@ -18,9 +18,11 @@ from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_ ...@@ -18,9 +18,11 @@ from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_
from transformers.pipelines import pipeline from transformers.pipelines import pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
is_torch_available,
nested_simplify, nested_simplify,
require_tf, require_tf,
require_torch, require_torch,
require_torch_gpu,
require_vision, require_vision,
slow, slow,
) )
...@@ -28,6 +30,10 @@ from transformers.testing_utils import ( ...@@ -28,6 +30,10 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY from .test_pipelines_common import ANY
if is_torch_available():
import torch
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
else: else:
...@@ -84,6 +90,37 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase): ...@@ -84,6 +90,37 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}] outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
) )
@require_torch
@require_torch_gpu
def test_small_model_pt_blip2(self):
    """Smoke-test the VQA pipeline with a tiny random BLIP-2 checkpoint.

    Checks the three supported call signatures (keyword args, single dict,
    list of dicts), then reloads the pipeline in float16 on GPU and
    verifies dtype/device placement plus a generated answer.
    """
    vqa_pipeline = pipeline(
        "visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
    )
    image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    question = "How many cats are there?"

    # Generative models return {"answer": ...} entries without scores.
    outputs = vqa_pipeline(image=image, question=question)
    self.assertEqual(outputs, [{"answer": ANY(str)}])

    outputs = vqa_pipeline({"image": image, "question": question})
    self.assertEqual(outputs, [{"answer": ANY(str)}])

    outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
    self.assertEqual(outputs, [[{"answer": ANY(str)}]] * 2)

    # Half-precision on GPU: both sub-models must be cast to float16.
    vqa_pipeline = pipeline(
        "visual-question-answering",
        model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
        model_kwargs={"torch_dtype": torch.float16},
        device=0,
    )
    self.assertEqual(vqa_pipeline.model.device, torch.device(0))
    self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
    self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)

    outputs = vqa_pipeline(image=image, question=question)
    self.assertEqual(outputs, [{"answer": ANY(str)}])
@slow @slow
@require_torch @require_torch
def test_large_model_pt(self): def test_large_model_pt(self):
...@@ -109,6 +146,31 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase): ...@@ -109,6 +146,31 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
[[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2, [[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
) )
@slow
@require_torch
@require_torch_gpu
def test_large_model_pt_blip2(self):
    """Integration test: BLIP-2 OPT-2.7B answers a real VQA prompt.

    Loads the full checkpoint in float16 on GPU (required for memory and
    matching the published results) and checks all three call signatures
    return the expected answer "two" for the two-cats COCO image.
    """
    vqa_pipeline = pipeline(
        "visual-question-answering",
        model="Salesforce/blip2-opt-2.7b",
        model_kwargs={"torch_dtype": torch.float16},
        device=0,
    )
    self.assertEqual(vqa_pipeline.model.device, torch.device(0))
    self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)

    image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    # BLIP-2 uses a prompt-style question format ("Question: ... Answer:").
    question = "Question: how many cats are there? Answer:"

    outputs = vqa_pipeline(image=image, question=question)
    self.assertEqual(outputs, [{"answer": "two"}])

    outputs = vqa_pipeline({"image": image, "question": question})
    self.assertEqual(outputs, [{"answer": "two"}])

    outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
    self.assertEqual(outputs, [[{"answer": "two"}]] * 2)
@require_tf @require_tf
@unittest.skip("Visual question answering not implemented in TF") @unittest.skip("Visual question answering not implemented in TF")
def test_small_model_tf(self): def test_small_model_tf(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment