wikitext.py 3.58 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens 
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.

Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
Leo Gao's avatar
Leo Gao committed
12
13
14
15
16
17
18
import os
import re
from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh
from best_download import download_file


19
20
21
22
23
24
25
26
27
28
29
30
# BibTeX entry for the dataset paper (Merity et al., 2016, arXiv:1609.07843),
# exposed so the harness can surface the proper citation for this task.
_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models}, 
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


Leo Gao's avatar
Leo Gao committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def wikitext_detokenizer(string):
    """Undo WikiText's Moses-style tokenization artifacts in *string*.

    Collapses the `` @-@ ``-style number separators, re-attaches punctuation,
    strips padding inside bracket/quote pairs, and restores ``==`` section
    markers. Substitutions run in a fixed order; do not reorder them.
    """
    text = string.replace("s '", "s'")
    # NOTE(review): the slashes make this look like a transliterated JS regex
    # literal, so it likely never matches — kept verbatim for parity. Confirm
    # against upstream before changing.
    text = re.sub(r"/' [0-9]/", r"/'[0-9]/", text)

    # Number separators: " @-@ " -> "-", " @,@ " -> ",", " @.@ " -> "."
    for wrapped, bare in ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", ".")):
        text = text.replace(wrapped, bare)

    # Punctuation: drop the space before the mark, keep the one after.
    for spaced, tight in (
        (" : ", ": "),
        (" ; ", "; "),
        (" . ", ". "),
        (" ! ", "! "),
        (" ? ", "? "),
        (" , ", ", "),
    ):
        text = text.replace(spaced, tight)

    # Strip padding just inside paired delimiters: (), [], {}, "", ''.
    text = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", text)
    text = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", text)
    text = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", text)
    text = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', text)
    text = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", text)

    # Headings, degree sign, line-edge spaces, number placeholder, possessives.
    for old, new in (
        ("= = = =", "===="),
        ("= = =", "==="),
        ("= =", "=="),
        (" " + chr(176) + " ", chr(176)),
        (" \n", "\n"),
        ("\n ", "\n"),
        (" N ", " 1 "),
        (" 's", "'s"),
    ):
        text = text.replace(old, new)

    return text


class WikiText(PerplexityTask):
    """Word-level perplexity on the raw WikiText-2 corpus.

    Each document is one top-level article section of the raw files.
    Targets are detokenized with ``wikitext_detokenizer``, while word counts
    are taken on the original (still-tokenized) text, matching standard
    WikiText perplexity reporting.
    """

    VERSION = 1

    def download(self):
        """Fetch and unpack WikiText-2 raw into data/wikitext/ (idempotent).

        Skips all work when the extracted validation file already exists.
        """
        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
            os.makedirs("data/wikitext/", exist_ok=True)
            # Checksum-verified download of the official archive.
            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")

    def has_validation_docs(self):
        return True

    def has_train_docs(self):
        return True

    def has_test_docs(self):
        return True

    def docs_for_split(self, split):
        """Yield one document per top-level article in ``wiki.{split}.raw``.

        A line that normalizes to ``= Title =`` (a single ``=`` on each side)
        starts a new article; deeper headings (``==``, ``===``) stay inside
        the current document. Chunks that are all whitespace are skipped,
        except the final chunk, which is always yielded (original behavior).
        """
        ret = []
        # Fix: read via a context manager so the file handle is closed
        # deterministically (the original open() was never closed).
        with open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw") as f:
            contents = f.read()
        for line in contents.split('\n'):
            rline = line.replace("= = =", "===").replace("= =", "==").strip()
            if rline.startswith('= ') and rline.strip().endswith(' ='):
                s = '\n'.join(ret)
                if s.strip():
                    yield s
                ret = []
            ret.append(line)
        yield '\n'.join(ret)

    def validation_docs(self):
        return self.docs_for_split('valid')

    def train_docs(self):
        return self.docs_for_split('train')

    def test_docs(self):
        return self.docs_for_split('test')

    def doc_to_target(self, doc):
        # Perplexity is scored against the detokenized text.
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        # Count words in the *original* doc, before detokenization, so the
        # per-word normalization matches standard WikiText evaluation.
        return len(re.split(r"\s+", doc))