Unverified Commit 896a0eb1 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2459 from Perseus14/patch-4

Update pipelines.py
parents a3085020 0d6c17fc
......@@ -705,9 +705,19 @@ class QuestionAnsweringPipeline(Pipeline):
# Convert inputs to features
examples = self._args_parser(*texts, **kwargs)
features = squad_convert_examples_to_features(
examples, self.tokenizer, kwargs["max_seq_len"], kwargs["doc_stride"], kwargs["max_question_len"], False
features_list = [
squad_convert_examples_to_features(
[example],
self.tokenizer,
kwargs["max_seq_len"],
kwargs["doc_stride"],
kwargs["max_question_len"],
False,
)
for example in examples
]
all_answers = []
for features, example in zip(features_list, examples):
fw_args = self.inputs_for_model([f.__dict__ for f in features])
# Manage tensor allocation on correct device
......@@ -724,13 +734,16 @@ class QuestionAnsweringPipeline(Pipeline):
start, end = start.cpu().numpy(), end.cpu().numpy()
answers = []
for (example, feature, start_, end_) in zip(examples, features, start, end):
for (feature, start_, end_) in zip(features, start, end):
# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_) / np.sum(np.exp(start_))
end_ = np.exp(end_) / np.sum(np.exp(end_))
# Mask padding and question
start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
start_, end_ = (
start_ * np.abs(np.array(feature.p_mask) - 1),
end_ * np.abs(np.array(feature.p_mask) - 1),
)
# TODO : What happens if not possible
# Mask CLS
......@@ -751,9 +764,12 @@ class QuestionAnsweringPipeline(Pipeline):
}
for s, e, score in zip(starts, ends, scores)
]
if len(answers) == 1:
return answers[0]
return answers
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
all_answers += answers
if len(all_answers) == 1:
return all_answers[0]
return all_answers
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment