Commit 5be42b4d authored by Charles Foster's avatar Charles Foster
Browse files

SQuAD fixed to use loglikelihood API to calculate the probability of an unanswerable question.

parent 884c29fb
import datasets import datasets
from lm_eval.base import rf, f1_score, mean from math import exp
from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from . common import HFTask from . common import HFTask
class SQuAD(HFTask): class SQuAD(HFTask):
...@@ -48,7 +50,8 @@ class SQuAD(HFTask): ...@@ -48,7 +50,8 @@ class SQuAD(HFTask):
part of the document for `doc`. part of the document for `doc`.
""" """
continuation = rf.greedy_until(ctx, ['\n']) continuation = rf.greedy_until(ctx, ['\n'])
return continuation is_unanswerable = rf.loglikelihood(ctx, [' unanswerable'])
return continuation, is_unanswerable
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -62,12 +65,12 @@ class SQuAD(HFTask): ...@@ -62,12 +65,12 @@ class SQuAD(HFTask):
""" """
squad_metric = datasets.load_metric("squad_v2") squad_metric = datasets.load_metric("squad_v2")
continuation, = results continuation, is_unanswerable = results
no_answer_probability = 0.0 logprob_unanswerable, is_greedy = is_unanswerable
if continuation.startswith(' unanswerable'):
no_answer_probability = 1.0
no_answer_probability = exp(logprob_unanswerable)
predictions = [{ predictions = [{
'id': doc['id'], 'id': doc['id'],
'prediction_text': continuation, 'prediction_text': continuation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment