Commit 5e31e40e authored by Leo Gao

Implement WikiText

parent efa99cb2
lm_eval/base.py
@@ -322,17 +322,17 @@ class PerplexityTask(Task, abc.ABC):
     def construct_requests(self, doc, ctx):
         assert not ctx
-        req = rf.loglikelihood_rolling(doc)
+        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         return req

     def process_results(self, doc, results):
         loglikelihood, = results
-        words = self.count_words(self.doc_to_target(doc))
-        bytes = self.count_bytes(self.doc_to_target(doc))
+        words = self.count_words(doc)
+        bytes = self.count_bytes(doc)
         return {
             "word_perplexity": (loglikelihood, words),
             "byte_perplexity": (loglikelihood, bytes),
-            "bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_target(doc)))
+            "bits_per_byte": (-loglikelihood, self.count_bytes(doc))
         }

@@ -342,12 +342,12 @@ class PerplexityTask(Task, abc.ABC):
             "bits_per_byte": weighted_mean
         }

-    def count_bytes(self, s):
-        return len(s.encode("utf-8"))
+    def count_bytes(self, doc):
+        return len(doc.encode("utf-8"))

-    def count_words(self, s):
+    def count_words(self, doc):
         """ Downstream tasks with custom word boundaries should override this! """
-        return len(re.split(r"\s+", s))
+        return len(re.split(r"\s+", doc))

 req_ret_lens = {
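Editorial note: each per-document metric value above is a (loglikelihood, count) pair, and the aggregator weighted_mean divides the summed loglikelihoods by the summed counts. A minimal sketch of how word-level perplexity falls out of this, assuming weighted_mean has the usual sum-ratio shape (the numbers are illustrative, not from any real run):

    import math

    def weighted_mean(items):
        # assumed shape of the aggregator referenced above:
        # items is a list of (numerator, denominator) pairs
        a, b = zip(*items)
        return sum(a) / sum(b)

    # two documents' (loglikelihood, word_count) results from process_results
    results = [(-123.4, 50), (-201.7, 80)]
    avg_ll_per_word = weighted_mean(results)      # about -2.50 nats per word
    word_perplexity = math.exp(-avg_ll_per_word)  # about 12.2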
lm_eval/tasks/__init__.py
@@ -38,6 +38,7 @@ from . import hendrycks_math
 from . import cbt
 from . import lambada_cloze
 from . import pile
+from . import wikitext

 ########################################
 # Translation tasks
@@ -95,6 +96,7 @@ TASK_REGISTRY = {
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
+    "wikitext": wikitext.WikiText,

     # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
     # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
@@ -113,7 +115,7 @@ TASK_REGISTRY = {
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC, # not implemented yet
     "logiqa": logiqa.LogiQA,
-    "hellaswag": hellaswag.HellaSwag, # not implemented yet
+    "hellaswag": hellaswag.HellaSwag,
     "openbookqa": openbookqa.OpenBookQA,
     # "sat": sat.SATAnalogies, # not implemented yet
     "squad2": squad.SQuAD2,
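With the registry entry in place, the task should be selectable by name from the harness entry point; assuming the main.py flags of this vintage (--model, --tasks), invocation would look roughly like:

    python main.py --model gpt2 --tasks wikitext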
lm_eval/tasks/wikitext.py
-from . common import HFTask
-
-class WikiText103(HFTask):
+import os
+import re
+from lm_eval.base import rf, PerplexityTask
+from lm_eval.utils import sh
+from best_download import download_file
+
+
+def wikitext_detokenizer(string):
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+
+    return string
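For intuition about wikitext_detokenizer, here is a trace of the replacement rules above on two illustrative inputs (the outputs were derived by hand from the rules, not taken from the dataset):

    print(wikitext_detokenizer("The Nile 's length is 6 @,@ 650 km ( 4 @,@ 130 mi ) , give or take"))
    # -> The Nile's length is 6,650 km (4,130 mi), give or take

    print(wikitext_detokenizer("= = Gameplay = ="))
    # -> == Gameplay ==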
+class WikiText(PerplexityTask):
     VERSION = 0
-    NLP_PATH = "wikitext"
-    NLP_NAME = "wikitext-103-raw-v1"
+
+    def download(self):
+        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
+            os.makedirs("data/wikitext/", exist_ok=True)
+            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
+            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")

     def fewshot_description(self):
         # TODO: figure out fewshot description
         return ""

-    def doc_to_text(self, doc):
-        # TODO: implement
-        pass
-
-    def doc_to_target(self, doc):
-        # TODO: implement
-        pass
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+    def has_validation_docs(self):
+        return True
+
+    def has_train_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True

-class WikiText2(HFTask):
-    VERSION = 0
-    NLP_PATH = "wikitext"
-    NLP_NAME = "wikitext-2-raw-v1"
-
-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
-    def doc_to_text(self, doc):
-        # TODO: implement
-        pass
+    def docs_for_split(self, split):
+        ret = []
+        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
+            rline = line.replace("= = =", "===").replace("= =", "==").strip()
+            if rline.startswith('= ') and rline.strip().endswith(' ='):
+                s = '\n'.join(ret)
+                if s.strip(): yield s
+                ret = []
+            ret.append(line)
+        yield '\n'.join(ret)
+
+    def validation_docs(self):
+        return self.docs_for_split('valid')
+
+    def train_docs(self):
+        return self.docs_for_split('train')
+
+    def test_docs(self):
+        return self.docs_for_split('test')
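To make the splitting behavior concrete: a top-level article heading ("= Title =" after collapsing the spaced equals signs) starts a new document, while deeper "== ... ==" headings do not. A stand-alone sketch of the same loop over an in-memory string (the function name and sample text here are illustrative, not part of the commit):

    def split_articles(text):
        # mirrors the loop in docs_for_split, minus the file I/O
        ret = []
        for line in text.split('\n'):
            rline = line.replace("= = =", "===").replace("= =", "==").strip()
            if rline.startswith('= ') and rline.strip().endswith(' ='):
                s = '\n'.join(ret)
                if s.strip():
                    yield s
                ret = []
            ret.append(line)
        yield '\n'.join(ret)

    sample = (
        " = First Article = \n"
        " The opening paragraph . \n"
        " = = Section = = \n"
        " More text . \n"
        " = Second Article = \n"
        " Another opening . "
    )
    docs = list(split_articles(sample))
    # len(docs) == 2: the "= = Section = =" line stays inside the first
    # document; only the top-level "= ... =" headings start a new one.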
     def doc_to_target(self, doc):
-        # TODO: implement
-        pass
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return wikitext_detokenizer(doc)

-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+
+    def count_words(self, doc):
+        # count number of words in *original doc before detokenization*
+        return len(re.split(r"\s+", doc))
\ No newline at end of file
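Putting the pieces together, the flow for a single document under the modified PerplexityTask is roughly as follows (a hypothetical walk-through; it assumes the harness base class runs download() at construction time, which matches the repo layout of this era):

    from lm_eval.tasks.wikitext import WikiText

    task = WikiText()                         # fetches wikitext-2-raw if needed, per download()
    doc = next(iter(task.validation_docs()))  # one article, original tokenization
    target = task.doc_to_target(doc)          # detokenized text the model is scored on
    n_words = task.count_words(doc)           # counted on the raw doc, per the override
    # construct_requests issues rf.loglikelihood_rolling(target); process_results
    # then pairs the returned loglikelihood with n_words (and byte counts) to
    # produce word_perplexity, byte_perplexity and bits_per_byte.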