Commit 5be42b4d authored by Charles Foster
Browse files

SQuAD fixed to use loglikelihood API to calculate the probability of an unanswerable question.

parent 884c29fb
import datasets
from lm_eval.base import rf, f1_score, mean
from math import exp
from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from . common import HFTask
class SQuAD(HFTask):
......@@ -48,7 +50,8 @@ class SQuAD(HFTask):
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, ['\n'])
return continuation
is_unanswerable = rf.loglikelihood(ctx, [' unanswerable'])
return continuation, is_unanswerable
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -62,12 +65,12 @@ class SQuAD(HFTask):
"""
squad_metric = datasets.load_metric("squad_v2")
continuation, = results
continuation, is_unanswerable = results
no_answer_probability = 0.0
if continuation.startswith(' unanswerable'):
no_answer_probability = 1.0
logprob_unanswerable, is_greedy = is_unanswerable
no_answer_probability = exp(logprob_unanswerable)
predictions = [{
'id': doc['id'],
'prediction_text': continuation,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment