lambada.py 1.14 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
from lm_eval.base import Dataset
sdtblck's avatar
sdtblck committed
2
from lm_eval.utils import sh
sdtblck's avatar
sdtblck committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import json
import requests
import ftfy


class Lambada(Dataset):
    """LAMBADA dataset task; only a test split is provided.

    ``download`` fetches the upstream JSONL file and caches it locally as a
    JSON list of passage strings, which ``test_docs`` later reads back.
    """

    # Local cache written by ``download`` and read by ``test_docs``.
    _DATA_PATH = "data/lambada/lambada_test.json"
    # Upstream JSONL file: one JSON object per line with a ``text`` field.
    _SOURCE_URL = "https://storage.googleapis.com/gpt-2/data/lambada_test.jsonl"

    def download(self):
        """Fetch the LAMBADA test split and cache it as a JSON list of strings.

        Each upstream record's ``text`` field is repaired with ``ftfy`` using
        NFKC normalization before being stored.

        Raises:
            requests.HTTPError: if the HTTP request returns an error status.
        """
        sh("mkdir -p data/lambada")
        # Fetch and parse everything *before* opening the output file, so a
        # failed request or a malformed line does not leave an empty or
        # truncated cache behind for ``test_docs`` to choke on.
        req = requests.get(self._SOURCE_URL)
        req.raise_for_status()
        # iter_lines() may yield blank lines (e.g. around a trailing
        # newline); json.loads would raise on those, so skip them.
        jsons = [json.loads(line) for line in req.iter_lines() if line.strip()]
        texts = [ftfy.fix_text(j['text'], normalization='NFKC') for j in jsons]
        with open(self._DATA_PATH, 'w', encoding='utf-8') as f:
            json.dump(texts, f)

    def has_training_docs(self):
        """LAMBADA ships no training split here."""
        return False

    def has_validation_docs(self):
        """LAMBADA ships no validation split here."""
        return False

    def has_test_docs(self):
        """The test split is available via ``test_docs``."""
        return True

    def training_docs(self):
        # No training split (see has_training_docs); returns None.
        pass

    def validation_docs(self):
        # No validation split (see has_validation_docs); returns None.
        pass

    def load_doc(self, myjson):
        """Normalize parsed JSON records into a list of passage strings.

        ``download`` caches plain strings, whereas raw upstream records are
        objects with a ``text`` field. Both forms are accepted so the cache
        written by ``download`` can be read back.

        Args:
            myjson: Iterable of passage strings and/or dicts with a ``text``
                key.

        Returns:
            List of passage strings, in input order.
        """
        return [doc['text'] if isinstance(doc, dict) else doc for doc in myjson]

    def test_docs(self):
        """Load the cached test split written by ``download``.

        Returns:
            List of passage strings.

        Raises:
            FileNotFoundError: if ``download`` has not been run.
        """
        with open(self._DATA_PATH, encoding='utf-8') as f:
            myjson = json.load(f)
        return self.load_doc(myjson)

    def doc_to_text(self, doc, include_target=True):
        # NOTE(review): not implemented yet; returns None.
        pass

    def evaluate(self, docs, lm, provide_description, num_fewshot):
        # NOTE(review): not implemented yet; returns None.
        pass