Commit 35d809c3 authored by silentv0x

Fixes DROP implementation

To match the official implementation, the following changes were made:
 - Validated answers are treated as additional gold answers
 - EM/F1 are computed as the max over all gold answers
parent 198ca732
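In effect, each document now carries a list of gold answers (the annotated answer plus any distinct validated answers), and the prediction is scored against each one, keeping the best score. A minimal sketch of that loop, assuming a `get_metrics(preds, gold)` function like the one defined later in this diff (the helper name `max_over_golds` is invented for illustration):

def max_over_golds(preds, golds, get_metrics):
    # `golds` is a list of answer tuples; keep the best EM/F1 over all of them.
    max_em, max_f1 = 0.0, 0.0
    for gold_answer in golds:
        em, f1 = get_metrics(preds, gold_answer)
        if gold_answer[0].strip():  # ignore empty gold answers
            max_em = max(max_em, em)
            max_f1 = max(max_f1, f1)
    return max_em, max_f1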
@@ -14,6 +14,7 @@ Acknowledgement: This implementation is based on the official evaluation for `DROP`
 https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
 """
+_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
 
 class DROP(Task):
     VERSION = 0
@@ -50,19 +51,34 @@ class DROP(Task):
             "id": qa["query_id"],
             "passage": doc["passage"],
             "question": qa["question"],
-            "answers": self.get_answers(qa["answer"]),
+            "answers": self.get_answers(qa),
         }
 
     @classmethod
-    def get_answers(cls, answers):
-        # NOTE: We wrap every non-`list` answer into a list for uniformity.
-        if answers["number"] != "":
-            return [str(answers["number"])]
-        if answers["spans"] != []:
-            return answers["spans"]
-        return [" ".join([answers["date"]["day"],
-                          answers["date"]["month"],
-                          answers["date"]["year"]]).strip()]
+    def get_answers(cls, qa):
+        answers = []
+        answers_set = set()
+        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        for candidate in candidates:
+            answer = cls.parse_answer(candidate)
+            if answer in answers_set:
+                continue
+            answers_set.add(answer)
+            answers.append(answer)
+        return answers
+
+    @classmethod
+    def parse_answer(cls, answer):
+        # NOTE: Everything is returned as a tuple for uniformity and hashability.
+        if answer["number"] != "":
+            return (str(answer["number"]),)
+        if answer["spans"] != []:
+            return tuple(answer["spans"])
+        return (" ".join([answer["date"]["day"],
+                          answer["date"]["month"],
+                          answer["date"]["year"]]).strip(),)
 
     def training_docs(self):
         docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
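For illustration, a hedged sketch of what the new `get_answers` returns for a question with validated answers (the record below is invented; the field layout follows the DROP data format):

qa = {
    "question": "How many field goals were kicked?",
    "answer": {"number": "3", "spans": [], "date": {"day": "", "month": "", "year": ""}},
    "validated_answers": [
        # Identical to the annotated answer, so it is deduplicated away.
        {"number": "3", "spans": [], "date": {"day": "", "month": "", "year": ""}},
        # A distinct span answer becomes a second gold answer.
        {"number": "", "spans": ["three"], "date": {"day": "", "month": "", "year": ""}},
    ],
}
# DROP.get_answers(qa) -> [("3",), ("three",)]
# Tuples are hashable, so duplicates can be filtered through `answers_set`.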
@@ -76,7 +92,7 @@ class DROP(Task):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
 
     def doc_to_target(self, doc):
-        return " " + ", ".join(doc["answers"])
+        return " " + ", ".join(doc["answers"][0])
 
     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
@@ -89,9 +105,7 @@ class DROP(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        conts = []
-        for _ in doc["answers"]:
-            conts.append(rf.greedy_until(ctx, ["."]))
+        conts = [rf.greedy_until(ctx, ["."])]
         return conts
 
     def process_results(self, doc, results):
@@ -105,66 +119,96 @@ class DROP(Task):
             The results of the requests created in construct_requests.
         """
         preds, golds = results, doc["answers"]
-        exact_match, f1_score = self.get_metrics(preds, golds)
+        max_em = 0
+        max_f1 = 0
+        for gold_answer in golds:
+            exact_match, f1_score = self.get_metrics(preds, gold_answer)
+            if gold_answer[0].strip():
+                max_em = max(max_em, exact_match)
+                max_f1 = max(max_f1, f1_score)
         return {
-            "em": exact_match,
-            "f1": f1_score
+            "em": max_em,
+            "f1": max_f1
         }
 
-    def get_metrics(self, preds, golds):
-        exact_match = self._exact_match(preds, golds)
-        f1_score = self._f1_score(preds, golds)
-        return exact_match, f1_score
-
-    def _exact_match(self, preds, golds):
-        """ Returns the exact match of normalized gold answers and predictions. """
-        normalized_preds = [self._normalize(pred) for pred in preds]
-        normalized_golds = [self._normalize(gold) for gold in golds]
-        is_equal_sets = set(normalized_preds) == set(normalized_golds)
-        is_equal_length = len(normalized_preds) == len(normalized_golds)
-        return int(is_equal_sets and is_equal_length)
-
-    def _f1_score(self, preds, golds):
-        """Returns the average F1-score over normalized gold answers and predictions.
-        From Section 5 of Dua et al. "DROP:...":
-            "When an answer has multiple spans, we first perform a one-to-one
-            alignment greedily based on bag-of-word overlap on the set of spans
-            and then compute average F1 over each span."
-        """
-        pred_bags = self._answer_to_bags(preds)
-        gold_bags = self._answer_to_bags(golds)
-        f1_per_bag = self._align_bags(pred_bags, gold_bags)
-        return np.mean(f1_per_bag)
-
-    def _answer_to_bags(self, answers):
-        return [set(self._normalize(answer).split()) for answer in answers]
-
-    def _align_bags(self, pred_bags, gold_bags):
-        """ Returns the max metric value over all the answers. """
-        scores = np.zeros([len(gold_bags), len(pred_bags)])
-        for gold_index, gold_bag in enumerate(gold_bags):
-            for pred_index, pred_bag in enumerate(pred_bags):
-                if self._is_number_match(pred_bag, gold_bag):
-                    scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag)
+    def get_metrics(self, predicted, gold):
+        """
+        Takes a predicted answer and a gold answer (that are both either a string or a list of
+        strings), and returns exact match and the DROP F1 metric for the prediction.
+        """
+        predicted_bags = self._answer_to_bags(predicted)
+        gold_bags = self._answer_to_bags(gold)
+
+        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+            exact_match = 1.0
+        else:
+            exact_match = 0.0
+
+        f1_per_bag = self._align_bags(predicted_bags[1], gold_bags[1])
+        f1 = np.mean(f1_per_bag)
+        f1 = round(f1, 2)
+        return exact_match, f1
+
+    def _answer_to_bags(self, answer):
+        if isinstance(answer, (list, tuple)):
+            raw_spans = answer
+        else:
+            raw_spans = [answer]
+        normalized_spans = []
+        token_bags = []
+        for raw_span in raw_spans:
+            normalized_span = self._normalize(raw_span)
+            normalized_spans.append(normalized_span)
+            token_bags.append(set(normalized_span.split()))
+        return normalized_spans, token_bags
+
+    def _align_bags(self, predicted, gold):
+        """
+        Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
+        between them and gets maximum metric values over all the answers.
+        """
+        scores = np.zeros([len(gold), len(predicted)])
+        for gold_index, gold_item in enumerate(gold):
+            for pred_index, pred_item in enumerate(predicted):
+                if self._match_numbers_if_present(gold_item, pred_item):
+                    scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
         row_ind, col_ind = linear_sum_assignment(-scores)
-        max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
+
+        max_scores = np.zeros([max(len(gold), len(predicted))])
         for row, column in zip(row_ind, col_ind):
             max_scores[row] = max(max_scores[row], scores[row, column])
         return max_scores
 
-    def _bag_f1(self, pred_bag, gold_bag):
-        intersection = len(gold_bag.intersection(pred_bag))
-        if intersection == 0:
-            return 0.0
-        precision = intersection / float(len(pred_bag)) if pred_bag else 1.0
-        recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
-        f1 = (2 * precision * recall) / (precision + recall)
+    def _compute_f1(self, predicted_bag, gold_bag):
+        intersection = len(gold_bag.intersection(predicted_bag))
+        if not predicted_bag:
+            precision = 1.0
+        else:
+            precision = intersection / float(len(predicted_bag))
+        if not gold_bag:
+            recall = 1.0
+        else:
+            recall = intersection / float(len(gold_bag))
+        f1 = (
+            (2 * precision * recall) / (precision + recall)
+            if not (precision == 0.0 and recall == 0.0)
+            else 0.0
+        )
         return f1
 
-    def _is_number_match(self, pred_bag, gold_bag):
-        pred_numbers = set([word for word in pred_bag if self._is_number(word)])
-        gold_numbers = set([word for word in gold_bag if self._is_number(word)])
+    def _match_numbers_if_present(self, gold_bag, predicted_bag):
+        gold_numbers = set()
+        predicted_numbers = set()
+        for word in gold_bag:
+            if self._is_number(word):
+                gold_numbers.add(word)
+        for word in predicted_bag:
+            if self._is_number(word):
+                predicted_numbers.add(word)
         if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
             return True
         return False
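A worked example of the new metric path, with invented strings. Spans are compared as bags of normalized tokens; a predicted/gold pair only scores if their numbers agree (or the gold has none), and multi-span answers are aligned one-to-one via `linear_sum_assignment`:

preds = ["Brandon Jacobs"]                   # the single greedy continuation
gold = ("Ahmad Bradshaw", "Brandon Jacobs")  # one gold answer with two spans

# _answer_to_bags(gold)[1] -> [{"ahmad", "bradshaw"}, {"brandon", "jacobs"}]
# The optimal alignment pairs the prediction with the second span (F1 = 1.0)
# and leaves the first span unmatched (F1 = 0.0), so:
# get_metrics(preds, gold) -> (0.0, 0.5)  # EM fails; F1 averages to 0.5
# process_results then takes the max of these scores over all gold answers.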
@@ -175,30 +219,29 @@ class DROP(Task):
         except ValueError:
             return False
 
-    def _normalize(self, answer):
-        def remove_articles(text):
-            regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
-            return re.sub(regex, " ", text)
-
-        def white_space_fix(text):
-            return " ".join(text.split())
-
-        def remove_punc(text):
-            exclude = set(string.punctuation)
-            if not self._is_number(text):
-                return "".join(ch for ch in text if ch not in exclude)
-            else:
-                return text
-
-        def fix_number(text):
-            return str(float(text)) if self._is_number(text) else text
-
-        def tokenize(text):
-            return re.split(" |-", text)
-
+    def _remove_articles(self, text):
+        return _ARTICLES.sub(" ", text)
+
+    def _white_space_fix(self, text):
+        return " ".join(text.split())
+
+    def _remove_punc(self, text):
+        exclude = set(string.punctuation)
+        if not self._is_number(text):
+            return "".join(ch for ch in text if ch not in exclude)
+        else:
+            return text
+
+    def _fix_number(self, text):
+        return str(float(text)) if self._is_number(text) else text
+
+    def _tokenize(self, text):
+        return re.split(" |-", text)
+
+    def _normalize(self, answer):
         tokens = [
-            white_space_fix(remove_articles(fix_number(remove_punc(token.lower()))))
-            for token in tokenize(answer)
+            self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
+            for token in self._tokenize(answer)
         ]
         tokens = [token for token in tokens if token.strip()]
         normalized = " ".join(tokens).strip()
...
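For experimentation, a standalone re-implementation of the normalization pipeline above (a sketch; `normalize` and `_is_number` are local stand-ins for the class methods):

import re
import string

_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
_PUNCT = set(string.punctuation)

def _is_number(text):
    try:
        float(text)
        return True
    except ValueError:
        return False

def normalize(answer):
    def clean(token):
        token = token.lower()
        if not _is_number(token):
            # Strip punctuation from non-numeric tokens only.
            token = "".join(ch for ch in token if ch not in _PUNCT)
        # Canonicalize numbers through float ("70" and "70.0" compare equal).
        token = str(float(token)) if _is_number(token) else token
        token = _ARTICLES.sub(" ", token)
        return " ".join(token.split())

    tokens = [clean(token) for token in re.split(" |-", answer)]
    return " ".join(token for token in tokens if token.strip()).strip()

print(normalize("The 1st-half: 70 points."))  # -> 1st half 70.0 points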