wikitext.py 2.99 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import re
from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh

from best_download import download_file


def wikitext_detokenizer(string):
    """Turn tokenized WikiText into more natural running text.

    WikiText stores punctuation as separate tokens and writes separators
    inside numbers as `` @-@ ``, `` @,@ `` and `` @.@ ``. The substitutions
    below reverse that. They run strictly in the listed order, and each step
    sees the output of the previous one (e.g. ``"= = = ="`` must be collapsed
    before ``"= = ="`` and ``"= ="``). Do not reorder them.

    Args:
        string: tokenized WikiText text.

    Returns:
        The text with every substitution step applied.
    """

    def plain(text, old, new):
        # Literal (non-regex) replacement of every occurrence.
        return text.replace(old, new)

    def pattern(text, old, new):
        # Regex replacement of every non-overlapping match.
        return re.sub(old, new, text)

    degree = chr(176)
    steps = (
        # contractions
        (plain, "s '", "s'"),
        # NOTE(review): this pattern uses JavaScript-style /.../ delimiters,
        # so it only matches the literal text "/' <digit>/" and rewrites it to
        # the literal "/'[0-9]/". It does NOT join "' <digit>". Left unchanged
        # because fixing it would alter the evaluation target text.
        (pattern, r"/' [0-9]/", r"/'[0-9]/"),
        # number separators
        (plain, " @-@ ", "-"),
        (plain, " @,@ ", ","),
        (plain, " @.@ ", "."),
        # punctuation: drop the space before the mark
        (plain, " : ", ": "),
        (plain, " ; ", "; "),
        (plain, " . ", ". "),
        (plain, " ! ", "! "),
        (plain, " ? ", "? "),
        (plain, " , ", ", "),
        # paired brackets / quotes: drop padding just inside the pair
        (pattern, r"\(\s*([^\)]*?)\s*\)", r"(\1)"),
        (pattern, r"\[\s*([^\]]*?)\s*\]", r"[\1]"),
        (pattern, r"{\s*([^}]*?)\s*}", r"{\1}"),
        (pattern, r"\"\s*([^\"]*?)\s*\"", r'"\1"'),
        (pattern, r"'\s*([^']*?)\s*'", r"'\1'"),
        # heading markers, longest run first
        (plain, "= = = =", "===="),
        (plain, "= = =", "==="),
        (plain, "= =", "=="),
        # miscellaneous
        (plain, " " + degree + " ", degree),
        (plain, " \n", "\n"),
        (plain, "\n ", "\n"),
        # NOTE(review): rewrites a standalone " N " token to " 1 "; the
        # reason is not evident from this file. Confirm before changing.
        (plain, " N ", " 1 "),
        (plain, " 's", "'s"),
    )

    text = string
    for apply, old, new in steps:
        text = apply(text, old, new)
    return text


class WikiText(PerplexityTask):
    """WikiText-2 (raw) perplexity task.

    Each document is one article. The raw split files are cut at top-level
    title lines of the form ``" = Title = "`` (see ``docs_for_split``).
    """

    VERSION = 0

    def download(self):
        """Fetch and unpack WikiText-2 (raw) under ``data/wikitext/`` if absent."""
        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
            os.makedirs("data/wikitext/", exist_ok=True)
            download_file(
                "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip",
                "data/wikitext/wikitext-2-raw-v1.zip",
                # Expected hash of the archive (presumably SHA-256; confirm
                # against best_download.download_file).
                "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11",
            )
            # -o: overwrite without prompting, so leftovers from an earlier,
            # interrupted extraction cannot block unzip on an interactive
            # "replace?" prompt.
            sh("cd data/wikitext/ && unzip -o wikitext-2-raw-v1.zip")

    def fewshot_description(self):
        # TODO: figure out fewshot description
        return ""

    def has_validation_docs(self):
        return True

    def has_train_docs(self):
        return True

    def has_test_docs(self):
        return True

    def docs_for_split(self, split):
        """Yield the articles of one WikiText-2 raw split, in file order.

        A new document starts at every top-level title line (``" = Title = "``).
        The title line is the first line of the document it opens. Lines are
        kept verbatim (still tokenized), and whitespace-only chunks are skipped.

        Args:
            split: file infix used in ``wiki.{split}.raw``. The callers in
                this class pass ``"train"``, ``"valid"`` and ``"test"``.

        Yields:
            str: one article per item, its lines joined with newlines.
        """
        path = f"data/wikitext/wikitext-2-raw/wiki.{split}.raw"
        # Read eagerly inside ``with`` so the file is closed right away, even
        # if the consumer never exhausts this generator. Decode explicitly as
        # UTF-8 instead of relying on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            lines = f.read().split('\n')

        ret = []
        for line in lines:
            # Collapse "= =" to "==" so that only level-1 titles ("= X =")
            # still start with "= " and end with " ="; section headings such
            # as "= = X = =" no longer match after this rewrite.
            rline = line.replace("= =", "==").strip()
            if rline.startswith('= ') and rline.endswith(' ='):
                s = '\n'.join(ret)
                if s.strip():
                    yield s
                ret = []
            ret.append(line)

        # Apply the same whitespace guard to the trailing article.
        s = '\n'.join(ret)
        if s.strip():
            yield s

    def validation_docs(self):
        return self.docs_for_split('valid')

    def train_docs(self):
        return self.docs_for_split('train')

    def test_docs(self):
        return self.docs_for_split('test')

    def doc_to_target(self, doc):
        """Return the detokenized article text as the target to score."""
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        """Count whitespace-separated pieces of the *original* (pre-detokenization) doc.

        NOTE(review): ``re.split`` yields empty strings when ``doc`` starts or
        ends with whitespace, and those are counted too. This is left as-is
        because changing it would shift word-level metrics for VERSION 0.
        """
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))