"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "eab0afc19ceaf9a31190777f5548312d2346cd44"
Unverified Commit 2b282296 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding `batch_size` test to QA pipeline. (#17330)

parent a4386d7e
...@@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): ...@@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
) )
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}) self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
# Using batch is OK
new_outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
)
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
self.assertEqual(outputs, new_outputs)
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
question_answerer = pipeline( question_answerer = pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment