Unverified Commit 67403413 authored by Ankur Goyal's avatar Ankur Goyal Committed by GitHub
Browse files

Change document question answering pipeline to always return an array (#19071)


Co-authored-by: default avatarAnkur Goyal <ankur@impira.com>
parent cc567e00
...@@ -383,8 +383,6 @@ class DocumentQuestionAnsweringPipeline(Pipeline): ...@@ -383,8 +383,6 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs) answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k] answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
if len(answers) == 1:
return answers[0]
return answers return answers
def postprocess_donut(self, model_outputs, **kwargs): def postprocess_donut(self, model_outputs, **kwargs):
......
...@@ -267,7 +267,7 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -267,7 +267,7 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
image = INVOICE_URL image = INVOICE_URL
question = "What is the invoice number?" question = "What is the invoice number?"
outputs = dqa_pipeline(image=image, question=question, top_k=2) outputs = dqa_pipeline(image=image, question=question, top_k=2)
self.assertEqual(nested_simplify(outputs, decimals=4), {"answer": "us-001"}) self.assertEqual(nested_simplify(outputs, decimals=4), [{"answer": "us-001"}])
@require_tf @require_tf
@unittest.skip("Document question answering not implemented in TF") @unittest.skip("Document question answering not implemented in TF")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment