Commit 7c9da714 authored by Jonathan Tow's avatar Jonathan Tow Committed by Jon Tow
Browse files

Refactor `Task` download

parent 7064d6b9
......@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import os
import re
from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh
from best_download import download_file
import inspect
import lm_eval.datasets.wikitext.wikitext
from lm_eval.base import PerplexityTask
_CITATION = """
......@@ -64,45 +63,33 @@ def wikitext_detokenizer(string):
class WikiText(PerplexityTask):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
DATASET_NAME = "wikitext-2-raw-v1"
def download(self):
if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
os.makedirs("data/wikitext/", exist_ok=True)
download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
def has_validation_docs(self):
def has_training_docs(self):
return True
def has_train_docs(self):
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def docs_for_split(self, split):
ret = []
for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
s = '\n'.join(ret)
if s.strip(): yield s
ret = []
ret.append(line)
yield '\n'.join(ret)
def validation_docs(self):
return self.docs_for_split('valid')
def training_docs(self):
return map(self._load_doc, self.dataset["train"])
def train_docs(self):
return self.docs_for_split('train')
def validation_docs(self):
return map(self._load_doc, self.dataset["validation"])
def test_docs(self):
return self.docs_for_split('test')
return map(self._load_doc, self.dataset["test"])
def _load_doc(self, doc):
return doc["page"]
def doc_to_target(self, doc):
return wikitext_detokenizer(doc)
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
......@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
"""
import numpy as np
from . common import HFTask
from lm_eval.base import rf
from ..metrics import mean
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -30,7 +29,7 @@ _CITATION = """
"""
class Winogrande(HFTask):
class Winogrande(Task):
VERSION = 0
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
......@@ -46,6 +45,14 @@ class Winogrande(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
......
......@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.0
Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html
"""
import numpy as np
import random
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -37,7 +35,7 @@ _CITATION = """
"""
class WinogradSchemaChallenge273(HFTask):
class WinogradSchemaChallenge273(Task):
VERSION = 0
DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273"
......@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask):
upper_pronouns = ["A", "An", "The", "She", "He",
"It", "They", "My", "His", "Her", "Their"]
def __init__(self):
super().__init__()
self.data = self.__clean_data()
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
def __clean_data(self):
def _load_doc(self, doc):
# The HF implementation of `wsc273` is not `partial evaluation` friendly.
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
data.append(doc)
return {"test": data}
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
return doc
def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options.
......@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask):
return option.replace(pronoun, pronoun.lower())
return option
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
......
......@@ -21,8 +21,7 @@ setuptools.setup(
python_requires='>=3.6',
install_requires=[
"black",
"best_download==0.0.9",
"datasets==1.15.1",
"datasets==2.0.0",
"click>=7.1",
"scikit-learn>=0.24.1",
"torch>=1.7",
......@@ -43,6 +42,7 @@ setuptools.setup(
"openai==0.6.4",
"jieba==0.42.1",
"nagisa==0.2.7",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
],
dependency_links=[
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment