"docs/source/vscode:/vscode.git/clone" did not exist on "9edf37583411f892cea9ae7d98156c85d7c087b1"
Commit 9a24e0cf authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Refactored qa pipeline argument handling + unittests

parent 63e36007
......@@ -98,14 +98,11 @@ class TextClassificationPipeline(Pipeline):
class QuestionAnsweringPipeline(Pipeline):
"""
Question Answering pipeline involving Tokenization and Inference.
TODO:
- top-k answers
- return start/end chars
- return score
"""
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
super().__init__(model, tokenizer)
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
@staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
......@@ -116,6 +113,55 @@ class QuestionAnsweringPipeline(Pipeline):
else:
return SquadExample(None, question, context, None, None, None)
@staticmethod
def handle_args(*inputs, **kwargs) -> List[SquadExample]:
    """
    Normalize the supported input formats into a list of SquadExample.

    Accepted forms:
      - positional: one or more SquadExample / dict / list thereof
      - ``X=...`` or ``data=...`` (sklearn / Keras compatibility): a
        SquadExample, a ``{'question':..., 'context':...}`` dict, or a
        list/tuple of either
      - ``question=...`` + ``context=...``: a string or list of strings
        for each, zipped pairwise into examples

    Returns:
        List[SquadExample]: one example per (question, context) pair.

    Raises:
        KeyError: a dict item is missing 'question' or 'context'.
        ValueError: an item has an unsupported type, or no recognized
            argument combination was supplied.
    """
    # Positional args behave exactly like X/data, so forward them.
    # Bug fix: the old `len(inputs) > 1` test never matched the call site
    # `handle_args(texts, **kwargs)` (always a 1-tuple), silently dropping
    # positional input; it could also overwrite an explicit X= kwarg.
    if inputs:
        forwarded = inputs[0] if len(inputs) == 1 else list(inputs)
        # An empty tuple means "no positional data was given" — fall
        # through to the keyword handling instead of clobbering it.
        if forwarded and 'X' not in kwargs and 'data' not in kwargs:
            kwargs['X'] = forwarded

    # Generic compatibility with sklearn and Keras
    # Batched data
    if 'X' in kwargs or 'data' in kwargs:
        data = kwargs['X'] if 'X' in kwargs else kwargs['data']

        # Accept tuples as well as lists; anything else is a single item.
        if isinstance(data, (list, tuple)):
            data = list(data)
        else:
            data = [data]

        for i, item in enumerate(data):
            if isinstance(item, dict):
                if any(k not in item for k in ['question', 'context']):
                    raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
                data[i] = QuestionAnsweringPipeline.create_sample(**item)

            elif isinstance(item, SquadExample):
                continue
            else:
                raise ValueError(
                    '{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
                    .format('X' if 'X' in kwargs else 'data')
                )
        inputs = data

    # Tabular input: parallel question/context columns, zipped pairwise.
    elif 'question' in kwargs and 'context' in kwargs:
        if isinstance(kwargs['question'], str):
            kwargs['question'] = [kwargs['question']]

        if isinstance(kwargs['context'], str):
            kwargs['context'] = [kwargs['context']]

        inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])]
    else:
        raise ValueError('Unknown arguments {}'.format(kwargs))

    # Always hand back a list, even for a single example.
    if not isinstance(inputs, list):
        inputs = [inputs]

    return inputs
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
    """Build the QA pipeline; simply delegates model/tokenizer storage to the base Pipeline."""
    super().__init__(model, tokenizer)
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
args = ['input_ids', 'attention_mask']
model_type = type(self.model).__name__.lower()
......@@ -131,10 +177,6 @@ class QuestionAnsweringPipeline(Pipeline):
else:
return {k: [feature.__dict__[k] for feature in features] for k in args}
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
def __call__(self, *texts, **kwargs):
# Set defaults values
kwargs.setdefault('topk', 1)
......@@ -149,29 +191,10 @@ class QuestionAnsweringPipeline(Pipeline):
if kwargs['max_answer_len'] < 1:
raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))
# Position args
if texts is not None and len(texts) > 1:
(texts, ) = texts
# Generic compatibility with sklearn and Keras
elif 'X' in kwargs and not texts:
texts = kwargs.pop('X')
# Batched data
elif 'data' in kwargs:
texts = kwargs.pop('data')
# Tabular input
elif 'question' in kwargs and 'context' in kwargs:
texts = QuestionAnsweringPipeline.create_sample(kwargs['question'], kwargs['context'])
else:
raise ValueError('Unknown arguments {}'.format(kwargs))
if not isinstance(texts, list):
texts = [texts]
examples = QuestionAnsweringPipeline.handle_args(texts, **kwargs)
# Convert inputs to features
features = squad_convert_examples_to_features(texts, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
fw_args = self.inputs_for_model(features)
if is_tf_available():
......@@ -188,7 +211,7 @@ class QuestionAnsweringPipeline(Pipeline):
start, end = start.cpu().numpy(), end.cpu().numpy()
answers = []
for (example, feature, start_, end_) in zip(texts, features, start, end):
for (example, feature, start_, end_) in zip(examples, features, start, end):
# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_) / np.sum(np.exp(start_))
end_ = np.exp(end_) / np.sum(np.exp(end_))
......
......@@ -40,7 +40,38 @@ class QuestionAnsweringPipelineTest(unittest.TestCase):
# Batch case with topk = 2
a = nlp(question=['What is the name of the company I\'m working for ?', 'Where is the company based ?'],
context=['I\'m working for Huggingface.', 'The company is based in New York and Paris'], topk=2)
context=['Where is the company based ?', 'The company is based in New York and Paris'], topk=2)
self.check_answer_structure(a, 2, 2)
# check for data keyword
a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'))
self.check_answer_structure(a, 1, 1)
a = nlp(data=nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'), topk=2)
self.check_answer_structure(a, 1, 2)
a = nlp(data=[
nlp.create_sample(question='What is the name of the company I\'m working for ?', context='I\'m working for Huggingface.'),
nlp.create_sample(question='I\'m working for Huggingface.', context='The company is based in New York and Paris'),
])
self.check_answer_structure(a, 2, 1)
a = nlp(data=[
{'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
{'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
])
self.check_answer_structure(a, 2, 1)
# X keywords
a = nlp(X=nlp.create_sample(
question='Where is the company based ?', context='The company is based in New York and Paris'
))
self.check_answer_structure(a, 1, 1)
a = nlp(X=[
{'question': 'What is the name of the company I\'m working for ?', 'context': 'I\'m working for Huggingface.'},
{'question': 'Where is the company based ?', 'context': 'The company is based in New York and Paris'},
], topk=2)
self.check_answer_structure(a, 2, 2)
@patch('transformers.pipelines.is_torch_available', return_value=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment