pile.py 1.46 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os

import lm_dataformat
import abc
import numpy as np
from lm_eval.base import rf, PerplexityTask
from ..metrics import mean, matthews_corrcoef, f1_score
from ..utils import general_detokenize
from best_download import download_file


class PilePerplexityTask(PerplexityTask, abc.ABC):

    PILE_SET_NAME = None
    VAL_PATH = 'data/pile/val.jsonl.zst'
    TEST_PATH = 'data/pile/test.jsonl.zst'

    def download(self):
        os.makedirs("data/pile/", exist_ok=True)
        if not os.path.exists(self.VAL_PATH):
            download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH)
        if not os.path.exists(self.TEST_PATH):
            download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH)

    def validation_docs(self):
        rdr = lm_dataformat.Reader(self.VAL_PATH)
        for doc, metadata in rdr.stream_data(get_meta=True):
            if metadata["pile_set_name"] == self.PILE_SET_NAME:
                yield doc

    def test_docs(self):
        rdr = lm_dataformat.Reader(self.TEST_PATH)
        for doc, metadata in rdr.stream_data(get_meta=True):
            if metadata["pile_set_name"] == self.PILE_SET_NAME:
                yield doc

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True


class PileEnronPerplexityTask(PilePerplexityTask):
    PILE_SET_NAME = "Enron Emails"


class PileUbuntuPerplexityTask(PilePerplexityTask):
    PILE_SET_NAME = "Ubuntu IRC"