Commit 96f3e5b3 authored by Stephen Hogg's avatar Stephen Hogg
Browse files

Disable extractive spans

parent c6a35696
...@@ -68,7 +68,7 @@ def categorise_answer(answer_blob): ...@@ -68,7 +68,7 @@ def categorise_answer(answer_blob):
return answer, answer_type return answer, answer_type
elif answer_blob["extractive_spans"]: elif answer_blob["extractive_spans"]:
answer = answer_blob["extractive_spans"] answer = answer_blob["extractive_spans"]
answer_type = "extractive spans" answer_type = "extractive_spans"
return answer, answer_type return answer, answer_type
elif answer_blob["yes_no"] is False: elif answer_blob["yes_no"] is False:
answer = "No" answer = "No"
...@@ -119,12 +119,15 @@ class QASPER(HFTask): ...@@ -119,12 +119,15 @@ class QASPER(HFTask):
+ "Q: " + "Q: "
+ doc["question"] + doc["question"]
+ "\n\n" + "\n\n"
+ "A: " + "A:"
) )
def doc_to_target(self, doc): def doc_to_target(self, doc):
# this method is invoked by tests only # this method is invoked by tests only
return " " + doc["answer"] answer = doc["answer"]
if isinstance(answer, list):
answer = ", ".join(answer)
return " " + answer
def training_docs(self): def training_docs(self):
for doc in self.data["train"]: for doc in self.data["train"]:
...@@ -156,28 +159,34 @@ class QASPER(HFTask): ...@@ -156,28 +159,34 @@ class QASPER(HFTask):
return obs_list return obs_list
def process_results(self, doc, results): def process_results(self, doc, results):
res, unanswerable = results # TODO: Calculate a score for extractive spans once a request type for generating
# extractive spans is available
if len(results) == 1:
[(logprob_unanswerable, _)] = results
elif len(results) == 2:
res, (logprob_unanswerable, _) = results
else:
ll_yes, ll_no, (logprob_unanswerable, _) = results
res_dict = {} res_dict = {}
# Handle unanswerability first # Handle unanswerability first
unanswerable_gold = doc["answer_type"] == "unanswerable" unanswerable_gold = doc["answer_type"] == "unanswerable"
unanswerable_pred = exp(unanswerable) > 1 - exp(unanswerable) unanswerable_pred = exp(logprob_unanswerable) > 1 - exp(logprob_unanswerable)
res_dict["f1_un"] = (unanswerable_gold, unanswerable_pred) res_dict["f1_un"] = (unanswerable_gold, unanswerable_pred)
# Handle yes/no questions # Handle yes/no questions
if doc["answer_type"] == "bool": if doc["answer_type"] == "bool":
ll_yes, ll_no = res
gold = 1 if doc["answer"] == "yes" else 0 gold = 1 if doc["answer"] == "yes" else 0
pred = ll_yes > ll_no pred = ll_yes > ll_no
res_dict["f1_yn"] = (gold, pred) res_dict["f1_yn"] = (gold, pred)
# Handle completions # Handle completions
if doc["answer_type"] == "free form answer": if doc["answer_type"] == "free form answer":
res_dict["f1_ab"] = token_f1_score(res["answer"], doc["answer"]) res_dict["f1_ab"] = token_f1_score(res, doc["answer"])
# Handle extraction # Handle extraction
if doc["answer_type"] == "extractive spans": if doc["answer_type"] == "extractive_spans":
res_dict["f1_ex"] = paragraph_f1_score(res["answer"], doc["answer"]) res_dict["f1_ex"] = 0
return res_dict return res_dict
def aggregation(self): def aggregation(self):
...@@ -195,14 +204,14 @@ class QASPER(HFTask): ...@@ -195,14 +204,14 @@ class QASPER(HFTask):
part of the document for `doc`. part of the document for `doc`.
""" """
unanswerable = rf.loglikelihood(ctx, " " + "unanswerable") unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
if doc["answer_type"] in ("free form answer", "extractive spans"): if doc["answer_type"] in ("free form answer", "extractive_spans"):
return rf.greedy_until(ctx, ["\n"]), unanswerable return [rf.greedy_until(ctx, ["\n"]), unanswerable]
elif doc["answer_type"] in ("bool"): elif doc["answer_type"] in ("bool"):
ll_yes, _ = rf.loglikelihood(ctx, " yes") ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no") ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no, unanswerable return [ll_yes, ll_no, unanswerable]
else: else:
return unanswerable return [unanswerable]
def higher_is_better(self): def higher_is_better(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment