Unverified Commit 66558b35 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #107 from EleutherAI/triviaqa_evaluation

add evaluation for TriviaQA dataset based on loglikelihood method
parents 0f30237a c8032a1a
...@@ -16,6 +16,7 @@ from . import arithmetic ...@@ -16,6 +16,7 @@ from . import arithmetic
from . import lambada from . import lambada
from . import race from . import race
from . import piqa from . import piqa
from . import triviaqa
TASK_REGISTRY = { TASK_REGISTRY = {
...@@ -44,6 +45,7 @@ TASK_REGISTRY = { ...@@ -44,6 +45,7 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy, # not implemented yet # "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet # "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
......
import os import os
import json import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset, mean, rf
from ..utils import sh from ..utils import sh
class TriviaQA(Dataset): class TriviaQA(Dataset):
...@@ -37,52 +37,41 @@ class TriviaQA(Dataset): ...@@ -37,52 +37,41 @@ class TriviaQA(Dataset):
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ''.join(['Q: ', doc['Question'], '\n\n','A: ']) return ''.join(['Q:', doc['Question'], '\n\n','A:'])
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc['Answer']['Aliases'][0] return " " + doc['Answer']['Value']
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of ret = []
Requests which will be sent to the LM. for alias in self._remove_prefixes(doc['Answer']['Aliases']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a return {
dict where keys are the names of submetrics and values are the values of "acc": float(any(results))
the metric for that one document }
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def aggregation(self): def aggregation(self):
""" return {
:returns: {str: [float] -> float} "acc": mean,
A dictionary where keys are the names of submetrics and values are }
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def higher_is_better(self): def higher_is_better(self):
""" return {
:returns: {str: bool} "acc": True
A dictionary where keys are the names of submetrics and values are }
whether a higher value of the submetric is better \ No newline at end of file
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment