Commit cbc5c9c8 authored by Leo Gao

squad: fix aggregation

parent 14dd29c4
@@ -3,6 +3,19 @@ from math import exp
 from lm_eval.base import rf
 from lm_eval.metrics import f1_score, mean
 from . common import HFTask
+from functools import partial
+
+
+def _squad_metric(predictions, references):
+    squad_metric = datasets.load_metric("squad_v2")
+    return squad_metric.compute(predictions=predictions, references=references)
+
+
+def _squad_agg(key, items):
+    predictions, references = zip(*items)
+    return _squad_metric(predictions=predictions, references=references)[key]
+
+
 class SQuAD(HFTask):
     DATASET_PATH = "squad_v2"
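The two module-level helpers defer metric computation to aggregation time: `process_results` now emits raw (prediction, reference) pairs, and `_squad_agg` transposes them back into the parallel lists that the squad_v2 metric expects before selecting a single key. A minimal sketch of that data flow (the ids and answer payloads are made-up placeholders, not from the commit):

```python
# Two made-up per-document items, shaped like process_results's output below:
items = [
    ({'id': 'doc-0', 'prediction_text': 'Normandy', 'no_answer_probability': 0.02},
     {'id': 'doc-0', 'answers': {'text': ['Normandy'], 'answer_start': [159]}}),
    ({'id': 'doc-1', 'prediction_text': '', 'no_answer_probability': 0.97},
     {'id': 'doc-1', 'answers': {'text': [], 'answer_start': []}}),
]

# zip(*items) transposes the list of pairs into parallel prediction/reference
# tuples, the shape squad_v2's compute() accepts.
predictions, references = zip(*items)
assert predictions[0]['id'] == references[0]['id'] == 'doc-0'
# _squad_agg('exact', items) would compute the full squad_v2 score dict over
# the corpus and return only the 'exact' entry.
```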
@@ -63,34 +76,31 @@ class SQuAD(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        squad_metric = datasets.load_metric("squad_v2")
-        continuation, is_unanswerable = results
-        logprob_unanswerable, is_greedy = is_unanswerable
+        continuation, (logprob_unanswerable, _) = results
         no_answer_probability = exp(logprob_unanswerable)
-        predictions = [{
+        predictions = {
             'id': doc['id'],
             'prediction_text': continuation,
             'no_answer_probability': no_answer_probability,
-        }]
-        references = [{
+        }
+        references = {
             'id': doc['id'],
             'answers': doc['answers'],
-        }]
-        metrics = squad_metric.compute(predictions=predictions, references=references)
-        metrics.pop('total', None)
-        metrics.pop('HasAns_total', None)
-        metrics.pop('NoAns_total', None)
-        metrics.pop('best_exact_thresh', None)
-        metrics.pop('best_f1_thresh', None)
-        return metrics
+        }
+        return {
+            'exact': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
+            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': (predictions, references),  # Exact match on answerable questions
+            'HasAns_f1': (predictions, references),  # F-score on answerable questions
+            'NoAns_exact': (predictions, references),  # Exact match on unanswerable questions
+            'NoAns_f1': (predictions, references),  # F-score on unanswerable questions
+            'best_exact': (predictions, references),  # Best exact match (with varying threshold)
+            'best_f1': (predictions, references),  # Best F1 (with varying threshold)
+        }

     def aggregation(self):
         """
@@ -99,14 +109,14 @@ class SQuAD(HFTask):
             functions that aggregate a list of metrics
         """
         return {
-            'exact': mean,  # Exact match (the normalized answer exactly matches the gold answer)
-            'f1': mean,  # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': mean,  # Exact match on answerable questions
-            'HasAns_f1': mean,  # F-score on answerable questions
-            'NoAns_exact': mean,  # Exact match on unanswerable questions
-            'NoAns_f1': mean,  # F-score on unanswerable questions
-            'best_exact': mean,  # Best exact match (with varying threshold)
-            'best_f1': mean,  # Best F1 (with varying threshold)
+            'exact': partial(_squad_agg, 'exact'),  # Exact match (the normalized answer exactly matches the gold answer)
+            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'),  # Exact match on answerable questions
+            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'),  # F-score on answerable questions
+            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'),  # Exact match on unanswerable questions
+            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'),  # F-score on unanswerable questions
+            'best_exact': partial(_squad_agg, 'best_exact'),  # Best exact match (with varying threshold)
+            'best_f1': partial(_squad_agg, 'best_f1'),  # Best F1 (with varying threshold)
         }
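`functools.partial` pre-binds the metric key, so each dictionary entry is the one-argument callable the aggregation interface expects. A small equivalence note (`agg_exact` is an illustrative name, not from the commit):

```python
from functools import partial

agg_exact = partial(_squad_agg, 'exact')  # one-argument aggregator
# agg_exact(items) is equivalent to _squad_agg('exact', items): the full
# squad_v2 computation over all collected items, indexed at 'exact'.
```

One consequence, assuming the harness invokes each aggregation function independently: every key triggers its own `_squad_metric` call over the identical item list, so the squad_v2 metric is recomputed once per reported key. Redundant, but correct.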
     def higher_is_better(self):