Unverified Commit 66558b35 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #107 from EleutherAI/triviaqa_evaluation

add evaluation for TriviaQA dataset based on loglikelihood method
parents 0f30237a c8032a1a
......@@ -16,6 +16,7 @@ from . import arithmetic
from . import lambada
from . import race
from . import piqa
from . import triviaqa
TASK_REGISTRY = {
......@@ -44,6 +45,7 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
"triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet
......
import os
import json
import random
from lm_eval.base import Dataset
from lm_eval.base import Dataset, mean, rf
from ..utils import sh
class TriviaQA(Dataset):
......@@ -37,52 +37,41 @@ class TriviaQA(Dataset):
return ""
def doc_to_text(self, doc):
    """Format a TriviaQA example as a question-answering prompt.

    :param doc: a TriviaQA record containing a 'Question' field.
    :returns: the prompt string ``"Q:<question>\n\nA:"`` that the model
        is asked to complete.
    """
    # Diff residue left two conflicting return statements here; keep the
    # merged version (no space after "Q:"/"A:"), which pairs with the
    # leading " " that doc_to_target puts on the answer.
    return ''.join(['Q:', doc['Question'], '\n\n', 'A:'])
def doc_to_target(self, doc):
    """Return the gold completion for a TriviaQA example.

    Uses the canonical answer ``Value`` (rather than the first alias) with
    a leading space so it concatenates cleanly after the "A:" prompt
    produced by doc_to_text.

    :param doc: a TriviaQA record with an 'Answer' dict containing 'Value'.
    :returns: ``" " + answer`` string.
    """
    # Diff residue left the removed Aliases[0] return above the merged one;
    # the merged behavior is the canonical Value.
    return " " + doc['Answer']['Value']
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
    """Build one loglikelihood request per (deduplicated) answer alias.

    Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc:
        The document as returned from training_docs, validation_docs, or
        test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the
        natural language description, the few shot examples, and the
        question part of the document for `doc`.
    :returns: list of is-greedy indicators, one per surviving alias; any of
        them being true counts as correct in process_results.
    """
    # The diff had spliced this loop into the middle of the docstring and
    # left the old NotImplementedError behind; this is the merged version.
    ret = []
    for alias in self._remove_prefixes(doc['Answer']['Aliases']):
        # Leading space matches the completion format of doc_to_target.
        _, is_prediction = rf.loglikelihood(ctx, " " + alias)
        ret.append(is_prediction)
    return ret
def process_results(self, doc, results):
    """Take a single document and the LM results and evaluate, returning a
    dict where keys are the names of submetrics and values are the values
    of the metric for that one document.

    The answer is counted correct iff ANY alias was the model's greedy
    completion.

    :param doc:
        The document as returned from training_docs, validation_docs, or
        test_docs (unused here; kept for the Dataset interface).
    :param results:
        The is-greedy booleans produced by the requests created in
        construct_requests, one per alias.
    :returns: {"acc": 1.0 or 0.0}
    """
    # Diff residue left the old `raise NotImplementedError` above the
    # merged return, making it unreachable; keep only the merged version.
    return {
        "acc": float(any(results))
    }
def aggregation(self):
    """
    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics.
    """
    # Per-document accuracies (0.0/1.0) are averaged into overall accuracy.
    # Diff residue left the old `raise NotImplementedError` above the
    # merged return; keep only the merged version.
    return {
        "acc": mean,
    }
def higher_is_better(self):
    """
    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better.
    """
    # Diff residue left the old `raise NotImplementedError` above the
    # merged return, making it unreachable; keep only the merged version.
    return {
        "acc": True
    }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment