Commit 35d809c3 authored by silentv0x's avatar silentv0x
Browse files

Fixes DROP implementation

Following the official implementation, the following changes were made
 - Validated answers are considered gold answers
 - EM/F1 are the max over all gold answers
parent 198ca732
......@@ -14,6 +14,7 @@ Acknowledgement: This implementation is based on the official evaluation for `DR
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task):
VERSION = 0
......@@ -50,19 +51,34 @@ class DROP(Task):
"id": qa["query_id"],
"passage": doc["passage"],
"question": qa["question"],
"answers": self.get_answers(qa["answer"]),
"answers": self.get_answers(qa),
}
@classmethod
def get_answers(cls, answers):
# NOTE: We wrap every non-`list` answer into a list for uniformity.
if answers["number"] != "":
return [str(answers["number"])]
if answers["spans"] != []:
return answers["spans"]
return [" ".join([answers["date"]["day"],
answers["date"]["month"],
answers["date"]["year"]]).strip()]
def get_answers(cls, qa):
answers = []
answers_set = set()
candidates = [qa["answer"]] + qa.get("validated_answers", [])
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
continue
answers_set.add(answer)
answers.append(answer)
return answers
@classmethod
def parse_answer(cls, answer):
# NOTE: Everything is returned as a tuple for uniformity and hashability.
if answer["number"] != "":
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (" ".join([answer["date"]["day"],
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
def training_docs(self):
docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
......@@ -76,7 +92,7 @@ class DROP(Task):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"])
return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
......@@ -89,9 +105,7 @@ class DROP(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
conts = []
for _ in doc["answers"]:
conts.append(rf.greedy_until(ctx, ["."]))
conts = [rf.greedy_until(ctx, ["."])]
return conts
def process_results(self, doc, results):
......@@ -105,66 +119,96 @@ class DROP(Task):
The results of the requests created in construct_requests.
"""
preds, golds = results, doc["answers"]
exact_match, f1_score = self.get_metrics(preds, golds)
max_em = 0
max_f1 = 0
for gold_answer in golds:
exact_match, f1_score = self.get_metrics(preds, gold_answer)
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {
"em": exact_match,
"f1": f1_score
"em": max_em,
"f1": max_f1
}
def get_metrics(self, preds, golds):
exact_match = self._exact_match(preds, golds)
f1_score = self._f1_score(preds, golds)
return exact_match, f1_score
def _exact_match(self, preds, golds):
""" Returns the exact match of normalized gold answers and predictions. """
normalized_preds = [self._normalize(pred) for pred in preds]
normalized_golds = [self._normalize(gold) for gold in golds]
is_equal_sets = set(normalized_preds) == set(normalized_golds)
is_equal_length = len(normalized_preds) == len(normalized_golds)
return int(is_equal_sets and is_equal_length)
def _f1_score(self, preds, golds):
"""Returns the average F1-score over normalized gold answers and predictions.
From Section 5 of Dua et al. "DROP:...":
"When an answer has multiple spans, we first perform a one-to-one
alignment greedily based on bag-of-word overlap on the set of spans
and then compute average F1 over each span."
def get_metrics(self, predicted, gold):
"""
Takes a predicted answer and a gold answer (that are both either a string or a list of
strings), and returns exact match and the DROP F1 metric for the prediction. If you are
writing a script for evaluating objects in memory (say, the output of predictions during
validation, or while training), this is the function you want to call, after using
:func:`answer_json_to_strings` when reading the gold answer from the released data file.
"""
pred_bags = self._answer_to_bags(preds)
gold_bags = self._answer_to_bags(golds)
f1_per_bag = self._align_bags(pred_bags, gold_bags)
return np.mean(f1_per_bag)
def _answer_to_bags(self, answers):
return [set(self._normalize(answer).split()) for answer in answers]
def _align_bags(self, pred_bags, gold_bags):
""" Returns the max metric value over all the answers. """
scores = np.zeros([len(gold_bags), len(pred_bags)])
for gold_index, gold_bag in enumerate(gold_bags):
for pred_index, pred_bag in enumerate(pred_bags):
if self._is_number_match(pred_bag, gold_bag):
scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag)
predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
exact_match = 1.0
else:
exact_match = 0.0
f1_per_bag = self._align_bags(predicted_bags[1], gold_bags[1])
f1 = np.mean(f1_per_bag)
f1 = round(f1, 2)
return exact_match, f1
def _answer_to_bags(self, answer):
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans = []
token_bags = []
for raw_span in raw_spans:
normalized_span = self._normalize(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(self, predicted, gold):
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
"""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores
def _bag_f1(self, pred_bag, gold_bag):
intersection = len(gold_bag.intersection(pred_bag))
if intersection == 0:
return 0.0
precision = intersection / float(len(pred_bag)) if pred_bag else 1.0
recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
f1 = (2 * precision * recall) / (precision + recall)
def _compute_f1(self, predicted_bag, gold_bag):
intersection = len(gold_bag.intersection(predicted_bag))
if not predicted_bag:
precision = 1.0
else:
precision = intersection / float(len(predicted_bag))
if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = (
(2 * precision * recall) / (precision + recall)
if not (precision == 0.0 and recall == 0.0)
else 0.0
)
return f1
def _is_number_match(self, pred_bag, gold_bag):
pred_numbers = set([word for word in pred_bag if self._is_number(word)])
gold_numbers = set([word for word in gold_bag if self._is_number(word)])
if (not gold_numbers) or gold_numbers.intersection(pred_numbers):
def _match_numbers_if_present(self, gold_bag, predicted_bag):
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
if self._is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if self._is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True
return False
......@@ -175,30 +219,29 @@ class DROP(Task):
except ValueError:
return False
def _normalize(self, answer):
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def _remove_articles(self, text):
return _ARTICLES.sub(" ", text)
def white_space_fix(text):
return " ".join(text.split())
def _white_space_fix(self, text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
if not self._is_number(text):
return "".join(ch for ch in text if ch not in exclude)
else:
return text
def _remove_punc(self, text):
exclude = set(string.punctuation)
if not self._is_number(text):
return "".join(ch for ch in text if ch not in exclude)
else:
return text
def fix_number(text):
return str(float(text)) if self._is_number(text) else text
def _fix_number(self, text):
return str(float(text)) if self._is_number(text) else text
def tokenize(text):
return re.split(" |-", text)
def _tokenize(text):
return re.split(" |-", text)
def _normalize(self, answer):
tokens = [
white_space_fix(remove_articles(fix_number(remove_punc(token.lower()))))
for token in tokenize(answer)
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
for token in self._tokenize(answer)
]
tokens = [token for token in tokens if token.strip()]
normalized = " ".join(tokens).strip()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment