"""
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens 
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.

Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import inspect
import re

import lm_eval.datasets.wikitext.wikitext
from lm_eval.base import PerplexityTask


_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models}, 
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string
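
# Illustrative examples (made-up strings, not dataset content) of what the
# detokenizer undoes: WikiText's space-padded punctuation, "@"-escaped
# number separators, and spaced-out section headers.
#
#   wikitext_detokenizer("= = History = =")
#       -> '== History =='
#   wikitext_detokenizer("the robot , built in 1993 @-@ 94 , cost $ 1 @.@ 5 million")
#       -> 'the robot, built in 1993-94, cost $ 1.5 million'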


class WikiText(PerplexityTask):
    VERSION = 1
    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return map(self._process_doc, self.dataset["train"])

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        return doc["page"]

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
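

# A sketch (an assumption about PerplexityTask's aggregation, not code from
# the upstream harness) of how the pieces above combine: per-document
# log-likelihoods of the detokenized targets are normalized by the raw word
# counts, giving word_perplexity = exp(-sum(loglikelihoods) / sum(words)).
if __name__ == "__main__":
    import math

    # Hypothetical numbers, for illustration only.
    loglikelihoods = [-1200.0, -800.0]  # nats, one entry per document
    word_counts = [600, 400]  # from count_words() on the raw documents
    word_ppl = math.exp(-sum(loglikelihoods) / sum(word_counts))
    print(f"word_perplexity = {word_ppl:.2f}")  # exp(2.0) ~= 7.39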