Commit cbc5c9c8 authored by Leo Gao

squad: fix aggregation

parent 14dd29c4
@@ -3,6 +3,19 @@ from math import exp
 from lm_eval.base import rf
 from lm_eval.metrics import f1_score, mean
 from . common import HFTask
+from functools import partial
+
+
+def _squad_metric(predictions, references):
+    squad_metric = datasets.load_metric("squad_v2")
+    return squad_metric.compute(predictions=predictions, references=references)
+
+
+def _squad_agg(key, items):
+    predictions, references = zip(*items)
+    return _squad_metric(predictions=predictions, references=references)[key]
 
 
 class SQuAD(HFTask):
     DATASET_PATH = "squad_v2"
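For context: squad_v2's scores, in particular best_exact and best_f1, are defined over the whole prediction set (the no-answer probability threshold is chosen globally), so they cannot be recovered by averaging per-document values. These helpers defer the single compute() call to aggregation time: _squad_agg unzips the collected (prediction, reference) pairs and pulls out one submetric. A minimal sketch of an aggregation call, using hypothetical ids and answers and assuming this file imports as lm_eval.tasks.squad:

from functools import partial

from lm_eval.tasks.squad import _squad_agg  # assumed module path for this file

# Two hypothetical (prediction, reference) pairs: one answerable, one not.
items = [
    ({'id': 'a1', 'prediction_text': 'Denver Broncos', 'no_answer_probability': 0.02},
     {'id': 'a1', 'answers': {'text': ['Denver Broncos'], 'answer_start': [177]}}),
    ({'id': 'b2', 'prediction_text': '', 'no_answer_probability': 0.95},
     {'id': 'b2', 'answers': {'text': [], 'answer_start': []}}),
]

# One corpus-level compute() per submetric, exactly as aggregation() wires it up below.
print(partial(_squad_agg, 'exact')(items))
print(partial(_squad_agg, 'best_f1')(items))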
@@ -63,34 +76,31 @@ class SQuAD(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        squad_metric = datasets.load_metric("squad_v2")
-
-        continuation, is_unanswerable = results
-
-        logprob_unanswerable, is_greedy = is_unanswerable
+        continuation, (logprob_unanswerable, _) = results
 
         no_answer_probability = exp(logprob_unanswerable)
 
-        predictions = [{
+        predictions = {
             'id': doc['id'],
             'prediction_text': continuation,
             'no_answer_probability': no_answer_probability,
-        }]
+        }
 
-        references = [{
+        references = {
             'id': doc['id'],
             'answers': doc['answers'],
-        }]
+        }
 
-        metrics = squad_metric.compute(predictions=predictions, references=references)
-        metrics.pop('total', None)
-        metrics.pop('HasAns_total', None)
-        metrics.pop('NoAns_total', None)
-        metrics.pop('best_exact_thresh', None)
-        metrics.pop('best_f1_thresh', None)
-        return metrics
+        return {
+            'exact': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
+            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
+            'HasAns_f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
+            'NoAns_f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'best_exact': (predictions, references),  # Best exact match (with varying threshold)
+            'best_f1': (predictions, references),  # Best F1 (with varying threshold)
+        }
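Note the shape change above: process_results no longer returns scores; every submetric key now maps to the same (predictions, references) tuple, one per document. Roughly, each document contributes something like the following (hypothetical values):

prediction = {
    'id': 'a1',  # hypothetical document id
    'prediction_text': 'Denver Broncos',
    'no_answer_probability': 0.02,
}
reference = {'id': 'a1', 'answers': {'text': ['Denver Broncos'], 'answer_start': [177]}}

# The harness is assumed to gather result[key] across documents into the
# `items` list that partial(_squad_agg, key) receives during aggregation.
result = {key: (prediction, reference) for key in (
    'exact', 'f1', 'HasAns_exact', 'HasAns_f1',
    'NoAns_exact', 'NoAns_f1', 'best_exact', 'best_f1',
)}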
 
     def aggregation(self):
         """
@@ -99,14 +109,14 @@ class SQuAD(HFTask):
             functions that aggregate a list of metrics
         """
         return {
-            'exact': mean,  # Exact match (the normalized answer exactly matches the gold answer)
-            'f1': mean,  # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': mean,  # Exact match (the normalized answer exactly matches the gold answer)
-            'HasAns_f1': mean,  # The F-score of predicted tokens versus the gold answer
-            'NoAns_exact': mean,  # Exact match (the normalized answer exactly matches the gold answer)
-            'NoAns_f1': mean,  # The F-score of predicted tokens versus the gold answer
-            'best_exact': mean,  # Best exact match (with varying threshold)
-            'best_f1': mean,  # Best F1 (with varying threshold)
+            'exact': partial(_squad_agg, 'exact'),  # Exact match (the normalized answer exactly matches the gold answer)
+            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'),  # Exact match (the normalized answer exactly matches the gold answer)
+            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'),  # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'),  # Exact match (the normalized answer exactly matches the gold answer)
+            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'),  # The F-score of predicted tokens versus the gold answer
+            'best_exact': partial(_squad_agg, 'best_exact'),  # Best exact match (with varying threshold)
+            'best_f1': partial(_squad_agg, 'best_f1'),  # Best F1 (with varying threshold)
         }
 
     def higher_is_better(self):
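End to end, the flow this commit assumes looks roughly like the sketch below (a simplified stand-in for the harness's evaluation loop, not its actual API). It also shows why the metrics.pop calls could be dropped: _squad_agg indexes compute()'s result by key, so extraneous entries such as total or best_exact_thresh never surface.

from collections import defaultdict

def aggregate_task(per_doc_results, aggregation_fns):
    """Hypothetical harness plumbing: per_doc_results holds one
    process_results() dict per document; aggregation_fns is the dict
    returned by SQuAD.aggregation()."""
    items_by_key = defaultdict(list)
    for doc_result in per_doc_results:
        for key, pair in doc_result.items():
            items_by_key[key].append(pair)
    # Each partial(_squad_agg, key) sees every document at once, so the
    # thresholded squad_v2 metrics are computed over the full set.
    return {key: aggregation_fns[key](items)
            for key, items in items_by_key.items()}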