Unverified Commit a2f5b74b authored by Stella Biderman, committed by GitHub

Merge pull request #101 from EleutherAI/lambada

Implement LAMBADA (#6)
parents 7031c324 822fcc6f
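For context: LAMBADA scores a model on predicting the final word of a passage, which is exactly what this PR implements via per-example log-likelihoods. A minimal sketch of the scoring idea (illustrative only; `lm_loglikelihood` and `score_example` are hypothetical names, not part of this PR):

```python
import math

def score_example(lm_loglikelihood, text):
    # Split the passage into context and final word, then ask the model
    # for the log-likelihood of that word and whether greedy decoding
    # would produce it -- mirroring construct_requests/process_results below.
    context, last_word = text.rsplit(' ', 1)
    ll, is_greedy = lm_loglikelihood(context, " " + last_word)
    return {'perplexity': math.exp(-ll), 'accuracy': int(is_greedy)}
```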
lm_eval/tasks/__init__.py
@@ -13,6 +13,7 @@ from . import squad
 from . import naturalqs
 from . import sat
 from . import arithmetic
+from . import lambada
 
 TASK_REGISTRY = {
     # GLUE
...
@@ -37,6 +38,8 @@ TASK_REGISTRY = {
     # Order by benchmark/genre?
+    "lambada": lambada.LAMBADA,
     # "arc_easy": arc.ARCEasy, # not implemented yet
     # "arc_challenge": arc.ARCChallenge, # not implemented yet
     # "quac": quac.QuAC, # not implemented yet
...
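Registering the class under the "lambada" key lets the harness construct the task by name. A minimal sketch of the lookup, assuming only the registry defined above:

```python
from lm_eval.tasks import TASK_REGISTRY

# Resolve the task class by its registry key and instantiate it.
task = TASK_REGISTRY["lambada"]()
print(task.has_training_docs())  # False: LAMBADA here is evaluation-only
```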
lm_eval/tasks/lambada.py
-from lm_eval.base import Dataset
+from lm_eval.base import Dataset, rf, mean
 from lm_eval.utils import sh
 import json
 import requests
 import ftfy
+import math
+from best_download import download_file
 
-class Lambada(Dataset):
-    def __init__(self):
-        self.download()
+class LAMBADA(Dataset):
     def download(self):
         sh("mkdir -p data/lambada")
-        with open("data/lambada/lambada_test.json", 'w') as f:
-            req = requests.get("https://storage.googleapis.com/gpt-2/data/lambada_test.jsonl")
-            req.raise_for_status()
-            jsons = [json.loads(l) for l in req.iter_lines()]
-            texts = [ftfy.fix_text(j['text'], normalization='NFKC') for j in jsons]
-            json.dump(texts, f)
+        download_file(
+            "https://storage.googleapis.com/gpt-2/data/lambada_test.jsonl",
+            "data/lambada/lambada_test.jsonl",
+            "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
+        )
 
     def has_training_docs(self):
         return False
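The switch to `best_download.download_file` pins the test file to a SHA-256 checksum. A rough sketch of the equivalent behavior, assuming only `requests` and the standard library (the real best_download package also handles retries and progress reporting):

```python
import hashlib
import requests

def fetch_with_checksum(url, local_path, expected_sha256):
    # Illustrative stand-in for best_download.download_file: fetch the
    # bytes, verify their SHA-256 digest, and only then write to disk.
    resp = requests.get(url)
    resp.raise_for_status()
    digest = hashlib.sha256(resp.content).hexdigest()
    if digest != expected_sha256:
        raise ValueError(f"checksum mismatch: {digest} != {expected_sha256}")
    with open(local_path, "wb") as fh:
        fh.write(resp.content)
```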
...
@@ -32,61 +31,42 @@ class Lambada(Dataset):
     def validation_docs(self):
         pass
 
-    def load_doc(self, myjson):
-        return [doc for doc in myjson]
-
     def test_docs(self):
-        myjson = json.load(open("data/lambada/lambada_test.json"))
-        return self.load_doc(myjson)
+        with open("data/lambada/lambada_test.jsonl") as fh:
+            for line in fh:
+                yield json.loads(line)
 
-    def doc_to_text(self, doc, include_target=True):
-        # TODO: implement.
+    def doc_to_text(self, doc):
+        return doc['text'].rsplit(' ', 1)[0]
+
+    def doc_to_target(self, doc):
+        return " " + doc['text'].rsplit(' ', 1)[1]
 
     def fewshot_description(self):
         # TODO: figure out description
         return ""
 
     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
+        return ll, is_greedy
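`rf.loglikelihood(ctx, continuation)` resolves to a pair: the total log-probability of the continuation given the context, and whether that continuation matches the model's greedy decoding. A conceptual sketch with made-up numbers (the real values come from the model backend):

```python
# Hypothetical per-token log-probs for the target continuation:
target_token_logprobs = [-1.05, -0.32]      # log P(token | prefix)
greedy_tokens = [" st", "arted"]            # hypothetical argmax continuation
target_tokens = [" st", "arted"]

ll = sum(target_token_logprobs)             # -1.37, the scalar returned as ll
is_greedy = greedy_tokens == target_tokens  # True: greedy decoding recovers the target
```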
 
     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll, is_greedy = results
+
+        return {
+            'perplexity': math.exp(-ll),
+            'accuracy': int(is_greedy)
+        }
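Each document thus contributes `exp(-ll)` (a per-example perplexity over the final word) and a 0/1 accuracy, which `aggregation` below averages over the test set. A worked example with hypothetical results:

```python
import math

# (log-likelihood of the target word, greedy-match flag) for three documents:
results = [(-1.2, True), (-3.0, False), (-0.4, True)]

per_doc = [{'perplexity': math.exp(-ll), 'accuracy': int(g)} for ll, g in results]
perplexity = sum(d['perplexity'] for d in per_doc) / len(per_doc)  # mean ~= 8.3
accuracy = sum(d['accuracy'] for d in per_doc) / len(per_doc)      # 2/3
```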
 
     def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            'perplexity': mean,
+            'accuracy': mean
+        }
 
     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            'perplexity': False,
+            'accuracy': True
+        }
\ No newline at end of file
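With all the hooks implemented, the task can be exercised end to end. A minimal sketch, assuming the `Dataset` base class invokes `download()` on construction (which the removal of the explicit `__init__` above suggests):

```python
from lm_eval.tasks.lambada import LAMBADA

task = LAMBADA()                    # download() assumed to run via the base class
doc = next(iter(task.test_docs()))
print(task.doc_to_text(doc))        # the passage minus its final word
print(task.doc_to_target(doc))      # " " + the final word to be predicted
```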