Unverified Commit 6a9726ec authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `DocumentQuestionAnsweringPipelineTests` (#19023)



* Fix DocumentQuestionAnsweringPipelineTests
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1207deb8
...@@ -113,13 +113,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -113,13 +113,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
question = "How many cats are there?" question = "How many cats are there?"
expected_output = [ expected_output = [
{ {"score": 0.0001, "answer": "oy 2312/2019", "start": 38, "end": 39},
"score": 0.0001, {"score": 0.0001, "answer": "oy 2312/2019 DUE", "start": 38, "end": 40},
"answer": "2312/2019 DUE DATE 26102/2019 ay DESCRIPTION UNIT PRICE",
"start": 38,
"end": 45,
},
{"score": 0.0001, "answer": "2312/2019 DUE", "start": 38, "end": 39},
] ]
outputs = dqa_pipeline(image=image, question=question, top_k=2) outputs = dqa_pipeline(image=image, question=question, top_k=2)
self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) self.assertEqual(nested_simplify(outputs, decimals=4), expected_output)
...@@ -170,8 +165,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -170,8 +165,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
...@@ -179,8 +174,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -179,8 +174,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
...@@ -191,8 +186,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -191,8 +186,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16},
], ],
] ]
* 2, * 2,
...@@ -219,8 +214,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -219,8 +214,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23},
], ],
) )
...@@ -228,8 +223,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -228,8 +223,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23},
], ],
) )
...@@ -240,8 +235,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -240,8 +235,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23},
] ]
] ]
* 2, * 2,
...@@ -254,8 +249,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -254,8 +249,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23},
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment