Unverified Commit 285c6262 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding a test to prevent late failure in the Table question answering pipeline (#9808)

- If the table is empty, then the line that contains `answers[0]` will fail.
- This PR adds a check to prevent indexing `answers[0]` when there is no answer.
- Also adds an early check for the presence of `table` and `query` to
prevent late failures and give a better error message.
- Adds a few tests to make sure these errors are correctly raised.
parent a46050d0
...@@ -3,7 +3,7 @@ import collections ...@@ -3,7 +3,7 @@ import collections
import numpy as np import numpy as np
from ..file_utils import add_end_docstrings, is_torch_available, requires_pandas from ..file_utils import add_end_docstrings, is_torch_available, requires_pandas
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException
if is_torch_available(): if is_torch_available():
...@@ -239,6 +239,10 @@ class TableQuestionAnsweringPipeline(Pipeline): ...@@ -239,6 +239,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
batched_answers = [] batched_answers = []
for pipeline_input in pipeline_inputs: for pipeline_input in pipeline_inputs:
table, query = pipeline_input["table"], pipeline_input["query"] table, query = pipeline_input["table"], pipeline_input["query"]
if table.empty:
raise ValueError("table is empty")
if not query:
raise ValueError("query is empty")
inputs = self.tokenizer( inputs = self.tokenizer(
table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding
) )
...@@ -276,5 +280,7 @@ class TableQuestionAnsweringPipeline(Pipeline): ...@@ -276,5 +280,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
answer["aggregator"] = aggregator answer["aggregator"] = aggregator
answers.append(answer) answers.append(answer)
if len(answer) == 0:
raise PipelineException("Empty answer")
batched_answers.append(answers if len(answers) > 1 else answers[0]) batched_answers.append(answers if len(answers) > 1 else answers[0])
return batched_answers if len(batched_answers) > 1 else batched_answers[0] return batched_answers if len(batched_answers) > 1 else batched_answers[0]
...@@ -131,6 +131,49 @@ class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): ...@@ -131,6 +131,49 @@ class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
self.assertIsInstance(table_querier.model.config.aggregation_labels, dict) self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int) self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)
with self.assertRaises(ValueError):
table_querier(
{
"table": {},
"query": "how many movies has george clooney played in?",
}
)
with self.assertRaises(ValueError):
table_querier(
{
"query": "how many movies has george clooney played in?",
}
)
with self.assertRaises(ValueError):
table_querier(
{
"table": {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
"Stars": ["36542", "4512", "3934"],
"Contributors": ["651", "77", "34"],
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
},
"query": "",
}
)
with self.assertRaises(ValueError):
table_querier(
{
"table": {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
"Stars": ["36542", "4512", "3934"],
"Contributors": ["651", "77", "34"],
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
},
}
)
def test_empty_errors(self):
table_querier = pipeline(
"table-question-answering",
model="lysandre/tiny-tapas-random-wtq",
tokenizer="lysandre/tiny-tapas-random-wtq",
)
mono_result = table_querier(self.valid_inputs[0], sequential=True) mono_result = table_querier(self.valid_inputs[0], sequential=True)
multi_result = table_querier(self.valid_inputs, sequential=True) multi_result = table_querier(self.valid_inputs, sequential=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment