Commit 0601d0bb authored by Anthony DiPofi
Browse files

add evaluation for TriviaQA dataset based on loglikelihood

parent 0f536808
...@@ -15,6 +15,7 @@ from . import sat ...@@ -15,6 +15,7 @@ from . import sat
from . import arithmetic from . import arithmetic
from . import lambada from . import lambada
from . import piqa from . import piqa
from . import triviaqa
TASK_REGISTRY = { TASK_REGISTRY = {
# GLUE # GLUE
...@@ -42,6 +43,7 @@ TASK_REGISTRY = { ...@@ -42,6 +43,7 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy, # not implemented yet # "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet # "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
......
...@@ -2,7 +2,6 @@ from lm_eval.base import Dataset, rf, mean ...@@ -2,7 +2,6 @@ from lm_eval.base import Dataset, rf, mean
from lm_eval.utils import sh from lm_eval.utils import sh
import json import json
import requests import requests
import ftfy
import math import math
from best_download import download_file from best_download import download_file
......
import os import os
import json import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset, mean, rf
from ..utils import sh from ..utils import sh
class TriviaQA(Dataset): class TriviaQA(Dataset):
...@@ -37,52 +37,27 @@ class TriviaQA(Dataset): ...@@ -37,52 +37,27 @@ class TriviaQA(Dataset):
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ''.join(['Q: ', doc['Question'], '\n\n','A: ']) return ''.join(['Q:', doc['Question'], '\n\n','A:'])
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc['Answer']['Aliases'][0] return doc['Answer']['Aliases'][0]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of ll, is_prediction = rf.loglikelihood(ctx,doc['Answer']['Value'])
Requests which will be sent to the LM. return is_prediction
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a is_prediction = results
dict where keys are the names of submetrics and values are the values of return {
the metric for that one document "acc": float(is_prediction[1])
}
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def aggregation(self): def aggregation(self):
""" return {
:returns: {str: [float] -> float} "acc": mean,
A dictionary where keys are the names of submetrics and values are }
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def higher_is_better(self): def higher_is_better(self):
""" return {
:returns: {str: bool} "acc": True
A dictionary where keys are the names of submetrics and values are }
whether a higher value of the submetric is better \ No newline at end of file
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
import os import os
from functools import reduce from functools import reduce
import operator import operator
import lm_dataformat as lmd
from tqdm import tqdm from tqdm import tqdm
import json import json
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment