Unverified Commit 2d61b3ce authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #103 from EleutherAI/piqa

Implement PiQA
parents a2f5b74b 63854c10
...@@ -59,7 +59,6 @@ class LM(abc.ABC): ...@@ -59,7 +59,6 @@ class LM(abc.ABC):
class Dataset(abc.ABC): class Dataset(abc.ABC):
@abc.abstractmethod
def __init__(self): def __init__(self):
self.download() self.download()
self._traindocs = None self._traindocs = None
......
...@@ -14,6 +14,7 @@ from . import naturalqs ...@@ -14,6 +14,7 @@ from . import naturalqs
from . import sat from . import sat
from . import arithmetic from . import arithmetic
from . import lambada from . import lambada
from . import piqa
TASK_REGISTRY = { TASK_REGISTRY = {
# GLUE # GLUE
...@@ -39,6 +40,7 @@ TASK_REGISTRY = { ...@@ -39,6 +40,7 @@ TASK_REGISTRY = {
# Order by benchmark/genre? # Order by benchmark/genre?
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
# "arc_easy": arc.ARCEasy, # not implemented yet # "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet # "arc_challenge": arc.ARCChallenge, # not implemented yet
......
import json import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset, rf, mean
from ..utils import sh from ..utils import sh
import os
class PiQA(Dataset): class PiQA(Dataset):
def __init__(self):
self.download()
def download(self): def download(self):
#pass if not os.path.exists('data/piqa'):
#TODO: don't download if files already there #TODO: use best_download
sh(""" sh("""
mkdir -p data/piqa mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
...@@ -25,11 +24,11 @@ class PiQA(Dataset): ...@@ -25,11 +24,11 @@ class PiQA(Dataset):
return True return True
def has_test_docs(self): def has_test_docs(self):
return True return False
def load_docs(self, textfilename, labelfilename): def load_docs(self, textfilename, labelfilename):
if labelfilename != None: if labelfilename != None:
return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(open(labelfilename, 'r'))) return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
else: else:
return [json.loads(entry) for entry in list(open(textfilename,'r'))] return [json.loads(entry) for entry in list(open(textfilename,'r'))]
...@@ -39,62 +38,40 @@ class PiQA(Dataset): ...@@ -39,62 +38,40 @@ class PiQA(Dataset):
def validation_docs(self): def validation_docs(self):
return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst') return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')
def test_docs(self): #def test_docs(self):
return self.load_docs('data/piqa/piqa-test.jsonl', None) # return self.load_docs('data/piqa/piqa-test.jsonl', None)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out fewshot description # TODO: figure out fewshot description
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
#TODO: check if oa uses newline return doc[0]['goal']
return doc['goal'] + ' '
def doc_to_target(self, doc): def doc_to_target(self, doc):
rightanswer = int(doc[1][0]) + 1 #TODO: check if oa uses newline
return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]]) rightanswer = int(doc[1]) + 1
return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
Requests which will be sent to the LM. ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])
:param doc: return ll_1, ll_2
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a ll_1, ll_2 = results
dict where keys are the names of submetrics and values are the values of
the metric for that one document return {
'acc': (ll_1 > ll_2) == (int(doc[1]) == 0)
:param doc: }
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def aggregation(self): def aggregation(self):
""" return {
:returns: {str: [float] -> float} 'acc': mean
A dictionary where keys are the names of submetrics and values are }
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def higher_is_better(self): def higher_is_better(self):
""" return {
:returns: {str: bool} 'acc': True
A dictionary where keys are the names of submetrics and values are }
whether a higher value of the submetric is better \ No newline at end of file
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment