"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2, specifically the raw
(`wikitext-2-raw-v1`) release.

Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/

@misc{merity2016pointer,
      title={Pointer Sentinel Mixture Models},
      author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
      year={2016},
      eprint={1609.07843},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
"""

import os
import re
import zipfile

from best_download import download_file

from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh


def wikitext_detokenizer(string):
    """Undo WikiText tokenization artifacts so the text reads like plain prose.

    The rewrite rules below are applied strictly in order; each one sees the
    output of the previous ones. They cover contractions, the ``@-@`` /
    ``@,@`` / ``@.@`` number separators, spaces before punctuation, padding
    inside bracket and quote pairs, spaced ``=`` heading markers, the spaced
    degree sign, spaces around newlines, the ``N`` token and the ``'s`` clitic.

    Args:
        string: tokenized WikiText text.

    Returns:
        The detokenized text.
    """
    # Each rule is (pattern, replacement, is_regex). Order matters.
    rules = (
        # contractions
        ("s '", "s'", False),
        # NOTE(review): in Python regexes the surrounding "/" are literal
        # characters, not delimiters. This rule therefore only rewrites the
        # literal text "/' <digit>/" into "/'[0-9]/"; it does NOT turn
        # "' 90s" into "'90s". Left unchanged on purpose, because fixing it
        # would change the evaluation targets produced by this function.
        (r"/' [0-9]/", r"/'[0-9]/", True),
        # number separators
        (" @-@ ", "-", False),
        (" @,@ ", ",", False),
        (" @.@ ", ".", False),
        # no space before punctuation
        (" : ", ": ", False),
        (" ; ", "; ", False),
        (" . ", ". ", False),
        (" ! ", "! ", False),
        (" ? ", "? ", False),
        (" , ", ", ", False),
        # trim the padding just inside paired brackets / quotes
        (r"\(\s*([^\)]*?)\s*\)", r"(\1)", True),
        (r"\[\s*([^\]]*?)\s*\]", r"[\1]", True),
        (r"{\s*([^}]*?)\s*}", r"{\1}", True),
        (r"\"\s*([^\"]*?)\s*\"", r'"\1"', True),
        (r"'\s*([^']*?)\s*'", r"'\1'", True),
        # heading markers, longest run first so shorter runs don't split it
        ("= = = =", "====", False),
        ("= = =", "===", False),
        ("= =", "==", False),
        # miscellaneous: degree sign, whitespace around newlines, N, 's
        (" " + chr(176) + " ", chr(176), False),
        (" \n", "\n", False),
        ("\n ", "\n", False),
        (" N ", " 1 ", False),
        (" 's", "'s", False),
    )

    text = string
    for pattern, replacement, is_regex in rules:
        if is_regex:
            text = re.sub(pattern, replacement, text)
        else:
            text = text.replace(pattern, replacement)
    return text


class WikiText(PerplexityTask):
    """Perplexity task over the raw WikiText-2 splits.

    Each document is one Wikipedia article, reconstructed from a split file by
    starting a new document at every top-level ``= Title =`` heading line.
    """

    VERSION = 1

    def download(self):
        """Download and extract the WikiText-2 raw archive into ``data/wikitext/``.

        Does nothing when the extracted validation split already exists. The
        archive is fetched with a pinned ``expected_checksum``.
        """
        if os.path.exists("data/wikitext/wikitext-2-raw/wiki.valid.raw"):
            return
        os.makedirs("data/wikitext/", exist_ok=True)
        download_file(
            "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip",
            local_file="data/wikitext/wikitext-2-raw-v1.zip",
            expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11",
        )
        # Extract with the stdlib rather than shelling out to `unzip`, which is
        # not available on every platform. The archive's top-level directory
        # ends up at data/wikitext/wikitext-2-raw/, as before.
        with zipfile.ZipFile("data/wikitext/wikitext-2-raw-v1.zip") as archive:
            archive.extractall("data/wikitext/")

    def has_validation_docs(self):
        return True

    def has_train_docs(self):
        return True

    def has_test_docs(self):
        return True

    def docs_for_split(self, split):
        """Yield the articles of one raw WikiText-2 split.

        Args:
            split: file infix of the split, i.e. ``"train"``, ``"valid"`` or
                ``"test"`` as passed by the ``*_docs`` methods of this class.

        Yields:
            Each article as its original, unmodified lines joined with
            newlines. A new article starts at every top-level ``= Title =``
            line. Chunks that are empty or whitespace-only are skipped.
        """
        path = f"data/wikitext/wikitext-2-raw/wiki.{split}.raw"
        # Explicit UTF-8: the files contain non-ASCII text (e.g. the degree
        # sign), which a non-UTF-8 platform default encoding may not decode.
        # The file is read fully and closed before the first yield.
        with open(path, encoding="utf-8") as f:
            text = f.read()

        current = []
        for line in text.split("\n"):
            # Collapse sub-section markers ("= = X = =" -> "== X ==") so only
            # level-1 headings still look like "= ... =" article titles.
            heading = line.replace("= = =", "===").replace("= =", "==").strip()
            if heading.startswith("= ") and heading.endswith(" ="):
                article = "\n".join(current)
                if article.strip():
                    yield article
                current = []
            current.append(line)

        # Flush the last article with the same emptiness check as above.
        article = "\n".join(current)
        if article.strip():
            yield article

    def validation_docs(self):
        return self.docs_for_split('valid')

    def train_docs(self):
        return self.docs_for_split('train')

    def test_docs(self):
        return self.docs_for_split('test')

    def doc_to_target(self, doc):
        """Return the detokenized article text used as the scoring target."""
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        """Return the number of words in *doc*, counted before detokenization.

        Presumably this is the word-level normaliser used by PerplexityTask;
        confirm against the base class.
        """
        # NOTE(review): re.split produces an empty leading/trailing token when
        # doc starts/ends with whitespace, and those tokens are counted here;
        # doc.split() would not count them. Kept as-is so reported word-level
        # results stay comparable.
        return len(re.split(r"\s+", doc))