"...resnet50_tensorflow.git" did not exist on "37e76715b39cdf8c857e04ff6613a82b1fd84413"
Unverified Commit c19727fd authored by Santiago Castro's avatar Santiago Castro Committed by GitHub
Browse files

Add support for the null answer in `QuestionAnsweringPipeline` (#3441)

* Add support for the null answer in `QuestionAnsweringPipeline`

* black

* Fix min null score computation

* Fix a PR comment
parent edf0582c
......@@ -944,6 +944,7 @@ class QuestionAnsweringPipeline(Pipeline):
kwargs.setdefault("max_answer_len", 15)
kwargs.setdefault("max_seq_len", 384)
kwargs.setdefault("max_question_len", 64)
kwargs.setdefault("handle_impossible_answer", False)
if kwargs["topk"] < 1:
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
......@@ -982,6 +983,7 @@ class QuestionAnsweringPipeline(Pipeline):
start, end = self.model(**fw_args)
start, end = start.cpu().numpy(), end.cpu().numpy()
min_null_score = 1000000 # large and positive
answers = []
for (feature, start_, end_) in zip(features, start, end):
# Normalize logits and spans to retrieve the answer
......@@ -994,8 +996,9 @@ class QuestionAnsweringPipeline(Pipeline):
end_ * np.abs(np.array(feature.p_mask) - 1),
)
# TODO : What happens if not possible
# Mask CLS
if kwargs["handle_impossible_answer"]:
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
start_[0] = end_[0] = 0
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
......@@ -1013,6 +1016,10 @@ class QuestionAnsweringPipeline(Pipeline):
}
for s, e, score in zip(starts, ends, scores)
]
if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
all_answers += answers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment