Commit 96f3e5b3 authored by Stephen Hogg
Browse files

Disable extractive spans

parent c6a35696
......@@ -68,7 +68,7 @@ def categorise_answer(answer_blob):
return answer, answer_type
elif answer_blob["extractive_spans"]:
answer = answer_blob["extractive_spans"]
answer_type = "extractive spans"
answer_type = "extractive_spans"
return answer, answer_type
elif answer_blob["yes_no"] is False:
answer = "No"
......@@ -119,12 +119,15 @@ class QASPER(HFTask):
+ "Q: "
+ doc["question"]
+ "\n\n"
+ "A: "
+ "A:"
)
def doc_to_target(self, doc):
# Builds the target string for `doc`: the value of doc["answer"] prefixed with a single space.
# NOTE(review): this listing is a flattened diff (no +/- markers, no indentation). The first
# `return` below looks like the pre-commit line and the four lines after it like its
# replacement -- confirm against the repository before treating both as live code.
# this method is invoked by tests only
return " " + doc["answer"]
# A list-valued answer is collapsed into one comma-separated string before the leading
# space is added; any other value is concatenated as-is.
answer = doc["answer"]
if isinstance(answer, list):
answer = ", ".join(answer)
return " " + answer
def training_docs(self):
for doc in self.data["train"]:
......@@ -156,28 +159,34 @@ class QASPER(HFTask):
return obs_list
def process_results(self, doc, results):
# Turns the raw request results for one document into a dict of per-metric values.
# Keys written below: "f1_un" (always), "f1_yn" (bool questions), "f1_ab" (free form
# answers) and "f1_ex" (extractive spans).
# NOTE(review): this listing is a flattened diff (no +/- markers, no indentation), so old and
# new statements appear side by side. The two-way unpack on the next line uses the name
# `unanswerable`, while the length-based unpack further down uses `logprob_unanswerable`;
# they look like the pre- and post-commit versions of the same step -- confirm.
res, unanswerable = results
# TODO: Calculate a score for extractive spans once a request type for generating
# extractive spans is available
# `results` is unpacked by length. The last entry is always a pair whose first element is
# bound to `logprob_unanswerable`:
#   1 entry   -> only that pair
#   2 entries -> `res` followed by that pair
#   otherwise -> `ll_yes`, `ll_no` followed by that pair
if len(results) == 1:
[(logprob_unanswerable, _)] = results
elif len(results) == 2:
res, (logprob_unanswerable, _) = results
else:
ll_yes, ll_no, (logprob_unanswerable, _) = results
res_dict = {}
# Handle unanswerability first
unanswerable_gold = doc["answer_type"] == "unanswerable"
# `exp(x) > 1 - exp(x)` is equivalent to exp(x) > 0.5: the prediction is True when the
# exponentiated log-probability exceeds one half. The first assignment uses the name from
# the two-way unpack above and is immediately overwritten by the second.
unanswerable_pred = exp(unanswerable) > 1 - exp(unanswerable)
unanswerable_pred = exp(logprob_unanswerable) > 1 - exp(logprob_unanswerable)
# Stored as a (gold, prediction) pair rather than a finished score.
res_dict["f1_un"] = (unanswerable_gold, unanswerable_pred)
# Handle yes/no questions
if doc["answer_type"] == "bool":
# NOTE(review): `ll_yes, ll_no = res` re-derives the pair from `res`; in the three-entry
# unpack above both names are already bound and `res` is not -- likely the pre-commit line.
ll_yes, ll_no = res
# NOTE(review): gold is 1 only for the exact lowercase string "yes". The
# categorise_answer fragment in this listing assigns "No" (capitalised); if the
# affirmative answer is capitalised too, this comparison never matches -- verify.
gold = 1 if doc["answer"] == "yes" else 0
pred = ll_yes > ll_no
res_dict["f1_yn"] = (gold, pred)
# Handle completions
if doc["answer_type"] == "free form answer":
# Two variants of the same assignment: one indexes res["answer"], the other passes `res`
# directly to token_f1_score; the later line wins if both execute.
res_dict["f1_ab"] = token_f1_score(res["answer"], doc["answer"])
res_dict["f1_ab"] = token_f1_score(res, doc["answer"])
# Handle extraction
# Two spellings of the answer type are tested: "extractive spans" (scored with
# paragraph_f1_score) and "extractive_spans" (hard-coded to 0, matching the TODO above).
if doc["answer_type"] == "extractive spans":
res_dict["f1_ex"] = paragraph_f1_score(res["answer"], doc["answer"])
if doc["answer_type"] == "extractive_spans":
res_dict["f1_ex"] = 0
return res_dict
def aggregation(self):
......@@ -195,14 +204,14 @@ class QASPER(HFTask):
part of the document for `doc`.
"""
unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
if doc["answer_type"] in ("free form answer", "extractive spans"):
return rf.greedy_until(ctx, ["\n"]), unanswerable
if doc["answer_type"] in ("free form answer", "extractive_spans"):
return [rf.greedy_until(ctx, ["\n"]), unanswerable]
elif doc["answer_type"] in ("bool"):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no, unanswerable
return [ll_yes, ll_no, unanswerable]
else:
return unanswerable
return [unanswerable]
def higher_is_better(self):
"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment