Commit 47382717 authored by Ben Wang's avatar Ben Wang
Browse files

add cloze variant of lambada task for allegeldy improved few shot results

parent efbe6e7f
......@@ -35,6 +35,7 @@ from . import unscramble
from . import logiqa
from . import hendrycks_test
from . import hendrycks_math
from . import lambada_cloze
########################################
# Translation tasks
......@@ -91,6 +92,7 @@ TASK_REGISTRY = {
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
"piqa": piqa.PiQA,
# Science related
......
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
class LAMBADA_cloze(Task):
def download(self):
sh("mkdir -p data/lambada")
download_file(
"http://eaidata.bmk.sh/data/lambada_test.jsonl",
"data/lambada/lambada_test.jsonl",
"4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
pass
def validation_docs(self):
with open("data/lambada/lambada_test.jsonl") as fh:
for line in fh:
yield json.loads(line)
def test_docs(self):
pass
def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
def doc_to_target(self, doc):
return " " + doc['text'].rsplit(' ', 1)[1]
def fewshot_description(self):
return "Fill in blank:\n"
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
return {
'ppl': ll,
'acc': int(is_greedy)
}
def aggregation(self):
return {
'ppl': perplexity,
'acc': mean
}
def higher_is_better(self):
return {
'ppl': False,
'acc': True
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment