"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens 
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.

Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
import inspect
import lm_eval.datasets.wikitext.wikitext
from lm_eval.base import PerplexityTask


_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models}, 
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    # NOTE: the slashes in this pattern are literal characters, so it rarely
    # (if ever) matches; it is left unchanged to preserve the original
    # detokenization behavior.
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string
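

# Illustrative example (hypothetical input, not part of the original file):
# tracing the replacements above on a WikiText-style tokenized string,
#
#   wikitext_detokenizer("the dog @-@ walker 's fee was 3 @.@ 50 , he said")
#   returns "the dog-walker's fee was 3.50, he said"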


class WikiText(PerplexityTask):
    VERSION = 1
    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return map(self._load_doc, self.dataset["train"])

    def validation_docs(self):
        return map(self._load_doc, self.dataset["validation"])

    def test_docs(self):
        return map(self._load_doc, self.dataset["test"])

    def _load_doc(self, doc):
        return doc["page"]

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def should_decontaminate(self):
        return True

    def count_words(self, doc):
        # Count the number of words in the *original* doc, before
        # detokenization. Note: splitting on r"\s+" includes the empty
        # strings produced by leading/trailing whitespace in the count.
        return len(re.split(r"\s+", doc))
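

# How this fits together (a minimal sketch; assumes the harness registers this
# task as "wikitext" in lm_eval.tasks, and that PerplexityTask aggregates
# word-level perplexity as exp(-sum(loglikelihoods) / sum(word counts)), with
# count_words supplying the denominator):
#
#   from lm_eval import tasks
#
#   task = tasks.get_task("wikitext")()
#   doc = next(iter(task.validation_docs()))  # raw page text
#   target = task.doc_to_target(doc)          # detokenized text to be scored
#   n_words = task.count_words(doc)           # denominator for word perplexity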