Unverified Commit 7ba8c183 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Use more inheritance

parent 47382717
......@@ -2,38 +2,11 @@ import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
class LAMBADA_cloze(Task):
def download(self):
sh("mkdir -p data/lambada")
download_file(
"http://eaidata.bmk.sh/data/lambada_test.jsonl",
"data/lambada/lambada_test.jsonl",
"4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
pass
def validation_docs(self):
with open("data/lambada/lambada_test.jsonl") as fh:
for line in fh:
yield json.loads(line)
def test_docs(self):
pass
class LAMBADA_cloze(LAMBADA):
def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
......@@ -42,28 +15,3 @@ class LAMBADA_cloze(Task):
def fewshot_description(self):
return "Fill in blank:\n"
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
return {
'ppl': ll,
'acc': int(is_greedy)
}
def aggregation(self):
return {
'ppl': perplexity,
'acc': mean
}
def higher_is_better(self):
return {
'ppl': False,
'acc': True
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment