Commit 5e31e40e authored by Leo Gao

Implement WikiText

parent efa99cb2
lm_eval/base.py
@@ -322,17 +322,17 @@ class PerplexityTask(Task, abc.ABC):

     def construct_requests(self, doc, ctx):
         assert not ctx
-        req = rf.loglikelihood_rolling(doc)
+        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         return req

     def process_results(self, doc, results):
         loglikelihood, = results
-        words = self.count_words(self.doc_to_target(doc))
-        bytes = self.count_bytes(self.doc_to_target(doc))
+        words = self.count_words(doc)
+        bytes = self.count_bytes(doc)
         return {
             "word_perplexity": (loglikelihood, words),
             "byte_perplexity": (loglikelihood, bytes),
-            "bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_target(doc)))
+            "bits_per_byte": (-loglikelihood, self.count_bytes(doc))
         }

     def aggregation(self):
@@ -342,12 +342,12 @@ class PerplexityTask(Task, abc.ABC):
             "bits_per_byte": weighted_mean
         }

-    def count_bytes(self, s):
-        return len(s.encode("utf-8"))
+    def count_bytes(self, doc):
+        return len(doc.encode("utf-8"))

-    def count_words(self, s):
+    def count_words(self, doc):
         """ Downstream tasks with custom word boundaries should override this! """
-        return len(re.split(r"\s+", s))
+        return len(re.split(r"\s+", doc))


 req_ret_lens = {
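For orientation, PerplexityTask now scores the rolling loglikelihood of the detokenized target (doc_to_target(doc)) while counting words and bytes on the original doc, and process_results records each metric as a (loglikelihood, unit count) pair that aggregation() reduces with weighted_mean. A minimal standalone sketch of the corpus-level quantities those pairs correspond to (the helper name and the exp / log-2 conversions are illustrative, assuming natural-log likelihoods; this is not the harness's own aggregation code):

    import math

    def corpus_perplexity_metrics(per_doc):
        # per_doc: one (loglikelihood, n_words, n_bytes) tuple per document,
        # i.e. the pieces produced by process_results above.
        total_ll = sum(ll for ll, _, _ in per_doc)
        total_words = sum(w for _, w, _ in per_doc)
        total_bytes = sum(b for _, _, b in per_doc)
        return {
            "word_perplexity": math.exp(-total_ll / total_words),
            "byte_perplexity": math.exp(-total_ll / total_bytes),
            "bits_per_byte": -total_ll / (total_bytes * math.log(2)),
        }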
lm_eval/tasks/__init__.py
@@ -38,6 +38,7 @@ from . import hendrycks_math
 from . import cbt
 from . import lambada_cloze
 from . import pile
+from . import wikitext

 ########################################
 # Translation tasks
@@ -95,6 +96,7 @@ TASK_REGISTRY = {
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
+    "wikitext": wikitext.WikiText,
     # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
     # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
@@ -113,7 +115,7 @@ TASK_REGISTRY = {
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC, # not implemented yet
     "logiqa": logiqa.LogiQA,
-    "hellaswag": hellaswag.HellaSwag, # not implemented yet
+    "hellaswag": hellaswag.HellaSwag,
     "openbookqa": openbookqa.OpenBookQA,
     # "sat": sat.SATAnalogies, # not implemented yet
     "squad2": squad.SQuAD2,
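With the import and registry entry in place, the new task is addressable by name. A hypothetical snippet showing what the registry entry resolves to (the harness's real CLI and evaluator drive this in practice, and constructor details may vary by version):

    from lm_eval.tasks import TASK_REGISTRY

    task = TASK_REGISTRY["wikitext"]()   # -> an instance of wikitext.WikiText
    task.download()                      # no-op if data/wikitext/ is already unpacked
    docs = list(task.validation_docs())  # one string per WikiText article
    print(len(docs), "validation documents")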
lm_eval/tasks/wikitext.py (rewritten)
-from . common import HFTask
-
-
-class WikiText103(HFTask):
-    VERSION = 0
-    NLP_PATH = "wikitext"
-    NLP_NAME = "wikitext-103-raw-v1"
-
-    def fewshot_description(self):
-        # TODO: figure out fewshot description
-        return ""
-
-    def doc_to_text(self, doc):
-        # TODO: implement
-        pass
-
-    def doc_to_target(self, doc):
-        # TODO: implement
-        pass
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-
-class WikiText2(HFTask):
-    VERSION = 0
-    NLP_PATH = "wikitext"
-    NLP_NAME = "wikitext-2-raw-v1"
 [... the remaining WikiText2 stub methods (fewshot_description, doc_to_text, doc_to_target, construct_requests, process_results, aggregation, higher_is_better), identical to the WikiText103 stubs above, are likewise removed ...]
+import os
+import re
+from lm_eval.base import rf, PerplexityTask
+from lm_eval.utils import sh
+from best_download import download_file
+
+
+def wikitext_detokenizer(string):
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+
+    return string
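wikitext_detokenizer undoes the dataset's tokenization artifacts (spaced punctuation, @-@ / @,@ / @.@ number joiners, doubled-up heading markers) before the text is handed to the model for scoring. A quick illustration on a made-up input string:

    raw = "The river is 6 @,@ 650 km ( 4 @,@ 130 mi ) long . It flows north , through the city 's delta"
    print(wikitext_detokenizer(raw))
    # -> The river is 6,650 km (4,130 mi) long. It flows north, through the city's delta

The new WikiText task class, which wires this detokenizer into PerplexityTask, follows: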
+
+
+class WikiText(PerplexityTask):
+    VERSION = 0
+
+    def download(self):
+        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
+            os.makedirs("data/wikitext/", exist_ok=True)
+            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
+            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
+
+    def fewshot_description(self):
+        # TODO: figure out fewshot description
+        return ""
+
+    def has_validation_docs(self):
+        return True
+
+    def has_train_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def docs_for_split(self, split):
+        ret = []
+        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
+            rline = line.replace("= =", "==").replace("= = =", "===").strip()
+            if rline.startswith('= ') and rline.strip().endswith(' ='):
+                s = '\n'.join(ret)
+                if s.strip(): yield s
+                ret = []
+            ret.append(line)
+        yield '\n'.join(ret)
+
+    def validation_docs(self):
+        return self.docs_for_split('valid')
+
+    def train_docs(self):
+        return self.docs_for_split('train')
+
+    def test_docs(self):
+        return self.docs_for_split('test')
+
+    def doc_to_target(self, doc):
+        return wikitext_detokenizer(doc)
+
+    def count_words(self, doc):
+        # count number of words in *original doc before detokenization*
+        return len(re.split(r"\s+", doc))
\ No newline at end of file
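docs_for_split treats each top-level WikiText heading (a line of the form "= Title =", as opposed to deeper "= = Section = =" headings) as the start of a new document, so perplexity is computed per article rather than per line. A standalone sketch of that splitting behaviour on a made-up snippet:

    sample = "\n".join([
        " = First Article = ",
        " Some text about the first article . ",
        " = = History = = ",
        " More text . ",
        " = Second Article = ",
        " Text about the second article . ",
    ])

    docs, ret = [], []
    for line in sample.split("\n"):
        # same heading test as docs_for_split: collapse "= =" so that only
        # top-level "= Title =" lines trigger a split
        rline = line.replace("= =", "==").replace("= = =", "===").strip()
        if rline.startswith("= ") and rline.strip().endswith(" ="):
            s = "\n".join(ret)
            if s.strip():
                docs.append(s)
            ret = []
        ret.append(line)
    docs.append("\n".join(ret))

    print(len(docs))  # -> 2: one document per top-level heading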