"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "8985f6c207fa36f52eb562f68afcc26c4dee149c"
Unverified Commit 7342d9a5 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Improve QA pipeline error handling (#8286)

- The issue is that with previous code we would have the following:

```python
qa_pipeline = (...)
qa_pipeline(question="Where was he born ?", context="")
-> IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
```

The goal here is to improve this to actually return a ValueError
wherever possible.

While at it, I tried to simplify QuestionArgumentHandler's code to
make it smaller and more compat while keeping backward compat.
parent 38630e7a
...@@ -22,6 +22,7 @@ import sys ...@@ -22,6 +22,7 @@ import sys
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from os.path import abspath, exists from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
...@@ -1597,55 +1598,52 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): ...@@ -1597,55 +1598,52 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
command-line supplied arguments. command-line supplied arguments.
""" """
def normalize(self, item):
if isinstance(item, SquadExample):
return item
elif isinstance(item, dict):
for k in ["question", "context"]:
if k not in item:
raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")
elif item[k] is None:
raise ValueError("`{}` cannot be None".format(k))
elif isinstance(item[k], str) and len(item[k]) == 0:
raise ValueError("`{}` cannot be empty".format(k))
return QuestionAnsweringPipeline.create_sample(**item)
raise ValueError("{} argument needs to be of type (SquadExample, dict)".format(item))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating # Detect where the actual inputs are
if args is not None and len(args) > 0: if args is not None and len(args) > 0:
if len(args) == 1: if len(args) == 1:
kwargs["X"] = args[0] inputs = args[0]
elif len(args) == 2 and {type(el) for el in args} == {str}:
inputs = [{"question": args[0], "context": args[1]}]
else: else:
kwargs["X"] = list(args) inputs = list(args)
# Generic compatibility with sklearn and Keras # Generic compatibility with sklearn and Keras
# Batched data # Batched data
if "X" in kwargs or "data" in kwargs: elif "X" in kwargs:
inputs = kwargs["X"] if "X" in kwargs else kwargs["data"] inputs = kwargs["X"]
elif "data" in kwargs:
if isinstance(inputs, dict): inputs = kwargs["data"]
inputs = [inputs]
else:
# Copy to avoid overriding arguments
inputs = [i for i in inputs]
for i, item in enumerate(inputs):
if isinstance(item, dict):
if any(k not in item for k in ["question", "context"]):
raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")
inputs[i] = QuestionAnsweringPipeline.create_sample(**item)
elif not isinstance(item, SquadExample):
raise ValueError(
"{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
"X" if "X" in kwargs else "data"
)
)
# Tabular input
elif "question" in kwargs and "context" in kwargs: elif "question" in kwargs and "context" in kwargs:
if isinstance(kwargs["question"], str): inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
kwargs["question"] = [kwargs["question"]]
if isinstance(kwargs["context"], str):
kwargs["context"] = [kwargs["context"]]
inputs = [
QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
]
else: else:
raise ValueError("Unknown arguments {}".format(kwargs)) raise ValueError("Unknown arguments {}".format(kwargs))
if not isinstance(inputs, list): # Normalize inputs
if isinstance(inputs, dict):
inputs = [inputs] inputs = [inputs]
elif isinstance(inputs, Iterable):
# Copy to avoid overriding arguments
inputs = [i for i in inputs]
else:
raise ValueError("Invalid arguments {}".format(inputs))
for i, item in enumerate(inputs):
inputs[i] = self.normalize(item)
return inputs return inputs
......
import unittest import unittest
from transformers.pipelines import Pipeline from transformers.data.processors.squad import SquadExample
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler
from .test_pipelines_common import CustomInputPipelineCommonMixin from .test_pipelines_common import CustomInputPipelineCommonMixin
...@@ -43,5 +44,116 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): ...@@ -43,5 +44,116 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
for key in output_keys: for key in output_keys:
self.assertIn(key, result) self.assertIn(key, result)
for bad_input in invalid_inputs: for bad_input in invalid_inputs:
self.assertRaises(Exception, nlp, bad_input) self.assertRaises(ValueError, nlp, bad_input)
self.assertRaises(Exception, nlp, invalid_inputs) self.assertRaises(ValueError, nlp, invalid_inputs)
def test_argument_handler(self):
qa = QuestionAnsweringArgumentHandler()
Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"
normalized = qa(Q, C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(question=Q, context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(question=Q, context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa({"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa([{"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa([{"question": Q, "context": C}, {"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(X={"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(X=[{"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(data={"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
def test_argument_handler_error_handling(self):
qa = QuestionAnsweringArgumentHandler()
Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"
with self.assertRaises(KeyError):
qa({"context": C})
with self.assertRaises(KeyError):
qa({"question": Q})
with self.assertRaises(KeyError):
qa([{"context": C}])
with self.assertRaises(ValueError):
qa(None, C)
with self.assertRaises(ValueError):
qa("", C)
with self.assertRaises(ValueError):
qa(Q, None)
with self.assertRaises(ValueError):
qa(Q, "")
with self.assertRaises(ValueError):
qa(question=None, context=C)
with self.assertRaises(ValueError):
qa(question="", context=C)
with self.assertRaises(ValueError):
qa(question=Q, context=None)
with self.assertRaises(ValueError):
qa(question=Q, context="")
with self.assertRaises(ValueError):
qa({"question": None, "context": C})
with self.assertRaises(ValueError):
qa({"question": "", "context": C})
with self.assertRaises(ValueError):
qa({"question": Q, "context": None})
with self.assertRaises(ValueError):
qa({"question": Q, "context": ""})
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": None, "context": C}])
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": "", "context": C}])
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": None}])
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
def test_argument_handler_error_handling_odd(self):
qa = QuestionAnsweringArgumentHandler()
with self.assertRaises(ValueError):
qa(None)
with self.assertRaises(ValueError):
qa(Y=None)
with self.assertRaises(ValueError):
qa(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment