Commit 7c9da714 authored by Jonathan Tow's avatar Jonathan Tow Committed by Jon Tow
Browse files

Refactor `Task` download

parent 7064d6b9
...@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2. ...@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/ Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
""" """
import os
import re import re
from lm_eval.base import rf, PerplexityTask import inspect
from lm_eval.utils import sh import lm_eval.datasets.wikitext.wikitext
from best_download import download_file from lm_eval.base import PerplexityTask
_CITATION = """ _CITATION = """
...@@ -64,45 +63,33 @@ def wikitext_detokenizer(string): ...@@ -64,45 +63,33 @@ def wikitext_detokenizer(string):
class WikiText(PerplexityTask):
    """Word-level perplexity task over the raw WikiText-2 corpus.

    Documents are loaded via the HF `datasets` machinery from the local
    dataset script referenced by ``DATASET_PATH``; each document is one
    wiki page (the ``"page"`` field of a dataset row).
    """

    VERSION = 1
    # Resolve the path of the bundled `wikitext` dataset script so the HF
    # loader uses the local copy rather than the hub version.
    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        # Lazily extract the page text from each train-split row.
        return (self._load_doc(row) for row in self.dataset["train"])

    def validation_docs(self):
        return (self._load_doc(row) for row in self.dataset["validation"])

    def test_docs(self):
        return (self._load_doc(row) for row in self.dataset["test"])

    def _load_doc(self, doc):
        """Extract the raw page text from a dataset row."""
        return doc["page"]

    def doc_to_target(self, doc):
        # Perplexity is measured on the detokenized page text.
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
...@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847 ...@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
""" """
import numpy as np import numpy as np
from . common import HFTask from lm_eval.base import rf, Task
from lm_eval.base import rf from lm_eval.metrics import mean
from ..metrics import mean
_CITATION = """ _CITATION = """
...@@ -30,7 +29,7 @@ _CITATION = """ ...@@ -30,7 +29,7 @@ _CITATION = """
""" """
class Winogrande(HFTask): class Winogrande(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "winogrande" DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl" DATASET_NAME = "winogrande_xl"
...@@ -46,6 +45,14 @@ class Winogrande(HFTask): ...@@ -46,6 +45,14 @@ class Winogrande(HFTask):
def has_test_docs(self):
    """This task evaluates on validation only; no test split is used."""
    return False
def training_docs(self):
    """Return the training documents, materialized and cached on first use."""
    docs = self._training_docs
    if docs is None:
        docs = list(self.dataset["train"])
        self._training_docs = docs
    return docs
def validation_docs(self):
    """Return the validation split directly (no caching, unlike training)."""
    return self.dataset["validation"]
def doc_to_text(self, doc):
    """Render the prompt context using the gold-answer option.

    ``doc["answer"]`` selects between ``option1``/``option2``; the chosen
    option is substituted into the partial context.
    """
    gold_option = doc["option" + doc["answer"]]
    return self.partial_context(doc, gold_option)
......
...@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.0 ...@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.0
Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html
""" """
import numpy as np import numpy as np
import random from lm_eval.base import rf, Task
from lm_eval.base import rf from lm_eval.metrics import mean
from ..metrics import mean
from . common import HFTask
_CITATION = """ _CITATION = """
...@@ -37,7 +35,7 @@ _CITATION = """ ...@@ -37,7 +35,7 @@ _CITATION = """
""" """
class WinogradSchemaChallenge273(HFTask): class WinogradSchemaChallenge273(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "winograd_wsc" DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273" DATASET_NAME = "wsc273"
...@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask): ...@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask):
upper_pronouns = ["A", "An", "The", "She", "He", upper_pronouns = ["A", "An", "The", "She", "He",
"It", "They", "My", "His", "Her", "Their"] "It", "They", "My", "His", "Her", "Their"]
def has_training_docs(self):
    """WSC273 is a test-set-only dataset; there is no training split."""
    return False
def has_validation_docs(self):
    """No validation split exists for this dataset."""
    return False
def has_test_docs(self):
    """Evaluation runs on the test split."""
    return True
def test_docs(self):
    """Lazily yield cleaned/normalized test documents."""
    return (self._load_doc(row) for row in self.dataset["test"])
def _load_doc(self, doc):
    """Normalize a raw `wsc273` document for partial-evaluation scoring.

    The HF implementation of `wsc273` is not `partial evaluation` friendly:
    the text contains doubled spaces and the candidate options are not
    formatted to drop into the pronoun slot. Squeeze the spaces and
    normalize both options in place, then return the mutated doc.
    """
    # BUG FIX: the code as rendered called `.replace(" ", " ")`, which is a
    # no-op — the intended call collapses a doubled space to a single one
    # (the double space appears to have been collapsed when this text was
    # copied/scraped).
    doc["text"] = doc["text"].replace("  ", " ")
    doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
    doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
    return doc
def __normalize_option(self, doc, option): def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options. # Append `'s` to possessive determiner based options.
...@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask): ...@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask):
return option.replace(pronoun, pronoun.lower()) return option.replace(pronoun, pronoun.lower())
return option return option
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are # NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset. # not available for this test-set-only dataset.
......
...@@ -21,8 +21,7 @@ setuptools.setup( ...@@ -21,8 +21,7 @@ setuptools.setup(
python_requires='>=3.6', python_requires='>=3.6',
install_requires=[ install_requires=[
"black", "black",
"best_download==0.0.9", "datasets==2.0.0",
"datasets==1.15.1",
"click>=7.1", "click>=7.1",
"scikit-learn>=0.24.1", "scikit-learn>=0.24.1",
"torch>=1.7", "torch>=1.7",
...@@ -43,6 +42,7 @@ setuptools.setup( ...@@ -43,6 +42,7 @@ setuptools.setup(
"openai==0.6.4", "openai==0.6.4",
"jieba==0.42.1", "jieba==0.42.1",
"nagisa==0.2.7", "nagisa==0.2.7",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
], ],
dependency_links=[ dependency_links=[
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt", "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment