Unverified commit 6caa0afd authored by Leo Gao, committed by GitHub

Merge pull request #300 from jon-tow/hf-dataset-refactor

Refactor `Task` downloading to use `HuggingFace.datasets`
Parents: 7064d6b9, 9434722c
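Note: the change is mechanical across tasks. Per-task `download()` methods built on `best_download` are dropped in favor of HuggingFace `datasets` loading driven by two class attributes. Below is a minimal sketch of the new task shape, assuming (the base-class plumbing is not shown in this diff) that `Task.download()` calls `datasets.load_dataset(self.DATASET_PATH, self.DATASET_NAME)` and stores the result as `self.dataset`:

import inspect

import lm_eval.datasets.truthfulqa.truthfulqa
from lm_eval.base import Task


class ExampleTask(Task):
    VERSION = 1
    # Path to an in-repo HF dataset script (here, the TruthfulQA one), so
    # `datasets.load_dataset` handles download, checksums, and caching.
    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
    DATASET_NAME = "multiple_choice"

    def has_validation_docs(self):
        return True

    def validation_docs(self):
        # Splits arrive as `datasets.Dataset` objects keyed by split name.
        return self.dataset["validation"]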
@@ -19,16 +19,14 @@ we could try this?

 Homepage: https://github.com/sylinrl/TruthfulQA
 """
-import csv
-import json
+import inspect
 import numpy as np
 import sacrebleu
+import datasets
+import lm_eval.datasets.truthfulqa.truthfulqa
 from rouge_score import rouge_scorer, scoring
 from lm_eval.base import rf, Task
-from pathlib import Path
-from best_download import download_file
-from ..metrics import mean
-from datasets import load_metric
+from lm_eval.metrics import mean


 _CITATION = """
@@ -62,15 +60,8 @@ QA_PROMPT = (

 class TruthfulQAMultipleChoice(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/mc')
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json"
-        checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954"
-        download_file(mc_url, local_file=str(self.DATASET_PATH / "mc_task.json"), expected_checksum=checksum)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "multiple_choice"

     def has_training_docs(self):
         return False
@@ -85,8 +76,7 @@ class TruthfulQAMultipleChoice(Task):
         raise NotImplementedError()

     def validation_docs(self):
-        with open(self.DATASET_PATH / "mc_task.json") as f:
-            return json.load(f)
+        return self.dataset["validation"]

     def test_docs(self):
         raise NotImplementedError()
@@ -121,7 +111,7 @@ class TruthfulQAMultipleChoice(Task):
             return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]

         # MC1 and MC2 targets are not always the same set of strings so we collect
         # likelihoods separately for simpler processing.
-        return get_lls(doc['mc1_targets']) + get_lls(doc['mc2_targets'])
+        return get_lls(doc['mc1_targets']["choices"]) + get_lls(doc['mc2_targets']["choices"])

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -139,14 +129,14 @@ class TruthfulQAMultipleChoice(Task):

         def mc2(lls):
             # Split on the first `0` as everything before it is true (`1`).
-            split_idx = list(doc['mc2_targets'].values()).index(0)
+            split_idx = list(doc['mc2_targets']["labels"]).index(0)
             # Compute the normalized probability mass for the correct answer.
             ll_true, ll_false = lls[:split_idx], lls[split_idx:]
             p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
             p_true = p_true / (sum(p_true) + sum(p_false))
             return sum(p_true)

-        split_idx = len(doc['mc1_targets'])
+        split_idx = len(doc['mc1_targets']["choices"])
         mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]

         return {
             "mc1": mc1(mc1_lls),
@@ -168,19 +158,12 @@ class TruthfulQAMultipleChoice(Task):

 class TruthfulQAGeneration(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/generation')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "generation"

     def __init__(self):
         super().__init__()
-        self.bleurt = load_metric("bleurt", cache_dir="lm_cache")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv"
-        checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2"
-        download_file(url, local_file=str(self.DATASET_PATH / "TruthfulQA.csv"), expected_checksum=checksum)
+        self.bleurt = datasets.load_metric("bleurt")

     def has_training_docs(self):
         return False
@@ -194,36 +177,29 @@ class TruthfulQAGeneration(Task):
     def training_docs(self):
         raise NotImplementedError()

-    def _split_multi_answer(self, answers, sep=';'):
-        answers = answers.strip().split(sep)
-        split_answers = []
+    def _format_answers(self, answers):
+        formatted_answers = []
         for answer in answers:
             answer = answer.strip()
             if len(answer):
                 # Add a period after all answers.
                 if answer[-1] != '.':
-                    split_answers.append(answer + '.')
+                    formatted_answers.append(answer + '.')
                 else:
-                    split_answers.append(answer)
-        return split_answers
+                    formatted_answers.append(answer)
+        return formatted_answers

     def validation_docs(self):
-        with open(self.DATASET_PATH / "TruthfulQA.csv", newline='') as csvfile:
-            doc_reader = csv.DictReader(csvfile)
-            for doc in doc_reader:
-                # Ensure that references exist.
-                if not doc['Correct Answers'] or not doc['Incorrect Answers']:
-                    continue
-                correct_answers = self._split_multi_answer(doc['Correct Answers'])
-                if "I have no comment." not in correct_answers:
-                    correct_answers.append("I have no comment.")
-                incorrect_answers = self._split_multi_answer(doc['Incorrect Answers'])
-                doc = {
-                    'question': doc['Question'].strip(),
-                    'correct_answers': correct_answers,
-                    'incorrect_answers': incorrect_answers
-                }
-                yield doc
+        for doc in self.dataset["validation"]:
+            incorrect_answers = self._format_answers(doc['incorrect_answers'])
+            correct_answers = self._format_answers(doc['correct_answers'])
+            if "I have no comment." not in correct_answers:
+                correct_answers.append("I have no comment.")
+            yield {
+                'question': doc['question'].strip(),
+                'correct_answers': correct_answers,
+                'incorrect_answers': incorrect_answers
+            }

     def test_docs(self):
         raise NotImplementedError()
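Note: for readers new to the TruthfulQA metrics in the hunks above, `mc2` is the normalized probability mass assigned to the true reference answers. A tiny worked example with made-up log-likelihoods, mirroring the `labels` layout the code splits on (all `1`s before the first `0`):

import numpy as np

# Hypothetical log-likelihoods: two true answers, then two false ones.
lls = [-1.0, -2.0, -3.0, -4.0]
labels = [1, 1, 0, 0]

split_idx = labels.index(0)
ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false))
print(sum(p_true))  # ~0.88: share of probability mass on the true answers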
......
@@ -8,11 +8,8 @@ addition, or deletion of characters, and asking it to recover the original word.

 Homepage: https://github.com/openai/gpt-3/tree/master/data
 """
-import gzip
-import json
-import shutil
-from pathlib import Path
-from best_download import download_file
+import inspect
+import lm_eval.datasets.unscramble.unscramble
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
@@ -32,30 +29,10 @@ _CITATION = """
 """


-def extract_gzip(gz, to):
-    with gzip.open(gz, 'rb') as fin:
-        with open(to, 'wb') as fout:
-            shutil.copyfileobj(fin, fout)
-
-
 class WordUnscrambleTask(Task):
     VERSION = 0
-    BASE_PATH = Path("data/unscramble")
-    FILENAME = None
-    CHECKSUM = None  # SHA256 Checksum.
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if not self.BASE_PATH.exists():
-            Path.mkdir(self.BASE_PATH, parents=True)
-        file = self.BASE_PATH / self.FILENAME
-        if not file.exists():
-            rawfile = file.parent / (file.name + ".gz")
-            base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
-            download_file(f"{base_url}/{self.FILENAME}.gz", local_file=str(rawfile), expected_checksum=self.CHECKSUM)
-            extract_gzip(gz=rawfile, to=file)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.unscramble.unscramble)
+    DATASET_NAME = None

     def has_training_docs(self):
         return False
@@ -67,8 +44,7 @@ class WordUnscrambleTask(Task):
         return False

     def validation_docs(self):
-        file = self.BASE_PATH / self.FILENAME
-        return (json.loads(line) for line in open(file).read().splitlines())
+        return self.dataset["validation"]

     def doc_to_text(self, doc):
         return doc["context"]
@@ -99,25 +75,20 @@ class WordUnscrambleTask(Task):

 class Anagrams1(WordUnscrambleTask):
-    FILENAME = "mid_word_1_anagrams.jsonl"
-    CHECKSUM = "6768a86896083199de4815d4964cb2f6f1046476cfd80c2a562784f182905979"
+    DATASET_NAME = "mid_word_1_anagrams"


 class Anagrams2(WordUnscrambleTask):
-    FILENAME = "mid_word_2_anagrams.jsonl"
-    CHECKSUM = "c3d839d09a7954b78a27cd2cd75d4ed0488656c56ef4dbd741a005343826cb01"
+    DATASET_NAME = "mid_word_2_anagrams"


 class CycleLetters(WordUnscrambleTask):
-    FILENAME = "cycle_letters_in_word.jsonl"
-    CHECKSUM = "1689c9002bb8c5988bf5f05e977c9db92f57932c1b5a38998c29ac0dd71e1d42"
+    DATASET_NAME = "cycle_letters_in_word"


 class RandomInsertion(WordUnscrambleTask):
-    FILENAME = "random_insertion_in_word.jsonl"
-    CHECKSUM = "72e65d83da53d15752ee0c47379509de149ddbad32d61184e5991df29616b78a"
+    DATASET_NAME = "random_insertion_in_word"


 class ReversedWords(WordUnscrambleTask):
-    FILENAME = "reversed_words.jsonl"
-    CHECKSUM = "133a08f875cd6c1ef8608a3233571a773881cc27b1c707de738cc6543439332a"
+    DATASET_NAME = "reversed_words"
@@ -9,9 +9,8 @@ The questions are popular ones asked on the web (at least in 2013).

 Homepage: https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a
 """
-from . common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean


 _CITATION = """
@@ -32,7 +31,7 @@ _CITATION = """
 """

-class WebQs(HFTask):
+class WebQs(Task):
     VERSION = 0
     DATASET_PATH = "web_questions"
     DATASET_NAME = None
@@ -46,6 +45,14 @@ class WebQs(HFTask):
     def has_test_docs(self):
         return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'
......
@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2.

 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
-import os
 import re
-from lm_eval.base import rf, PerplexityTask
-from lm_eval.utils import sh
-from best_download import download_file
+import inspect
+import lm_eval.datasets.wikitext.wikitext
+from lm_eval.base import PerplexityTask


 _CITATION = """
@@ -64,45 +63,33 @@ def wikitext_detokenizer(string):

 class WikiText(PerplexityTask):
     VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
+    DATASET_NAME = "wikitext-2-raw-v1"

-    def download(self):
-        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
-            os.makedirs("data/wikitext/", exist_ok=True)
-            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
-            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
-
-    def has_validation_docs(self):
+    def has_training_docs(self):
         return True

-    def has_train_docs(self):
+    def has_validation_docs(self):
         return True

     def has_test_docs(self):
         return True

-    def docs_for_split(self, split):
-        ret = []
-        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
-            rline = line.replace("= = =", "===").replace("= =", "==").strip()
-            if rline.startswith('= ') and rline.strip().endswith(' ='):
-                s = '\n'.join(ret)
-                if s.strip(): yield s
-                ret = []
-            ret.append(line)
-        yield '\n'.join(ret)
-
-    def validation_docs(self):
-        return self.docs_for_split('valid')
+    def training_docs(self):
+        return map(self._load_doc, self.dataset["train"])

-    def train_docs(self):
-        return self.docs_for_split('train')
+    def validation_docs(self):
+        return map(self._load_doc, self.dataset["validation"])

     def test_docs(self):
-        return self.docs_for_split('test')
+        return map(self._load_doc, self.dataset["test"])
+
+    def _load_doc(self, doc):
+        return doc["page"]

     def doc_to_target(self, doc):
         return wikitext_detokenizer(doc)

     def count_words(self, doc):
         # count number of words in *original doc before detokenization*
         return len(re.split(r"\s+", doc))
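Note: `wikitext_detokenizer` and `count_words` are unchanged by this diff. As a sketch of why both exist (inferred from the function name; the body is elided here): WikiText ships space-padded tokenized text, the detokenizer undoes that padding to build the target the model actually scores, while perplexity is still normalized by the word count of the raw document:

# Sketch of the kind of rewrites a WikiText detokenizer performs.
s = "The tour @-@ de @-@ force piece was written in 1999 ."
s = s.replace(" @-@ ", "-")  # re-join hyphenated compounds
s = s.replace(" .", ".")     # re-attach sentence-final punctuation
print(s)  # "The tour-de-force piece was written in 1999."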
@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847

 Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
 """
 import numpy as np
-from . common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean


 _CITATION = """
@@ -30,7 +29,7 @@ _CITATION = """
 """

-class Winogrande(HFTask):
+class Winogrande(Task):
     VERSION = 0
     DATASET_PATH = "winogrande"
     DATASET_NAME = "winogrande_xl"
@@ -46,6 +45,14 @@ class Winogrande(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return self.partial_context(doc, doc["option" + doc["answer"]])
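Note: `partial_context` is defined outside this diff. For context, Winogrande is scored by "partial evaluation" (Trinh & Le, 2018): substitute a candidate at the blank, then score only the continuation after it. A hedged sketch of that substitution, using field names from the HF `winogrande` schema:

# Illustrative doc; the substitution below sketches what `partial_context`
# is expected to do, not this repo's exact implementation.
doc = {
    "sentence": "The trophy does not fit in the suitcase because _ is too big.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

blank = doc["sentence"].index("_")
context = doc["sentence"][:blank] + doc["option" + doc["answer"]]
continuation = doc["sentence"][blank + 1:]
print(context)       # ends with "...because the trophy"
print(continuation)  # " is too big."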
......
@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.02847

 Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html
 """
 import numpy as np
-import random
-from lm_eval.base import rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean


 _CITATION = """
@@ -37,7 +35,7 @@ _CITATION = """
 """

-class WinogradSchemaChallenge273(HFTask):
+class WinogradSchemaChallenge273(Task):
     VERSION = 0
     DATASET_PATH = "winograd_wsc"
     DATASET_NAME = "wsc273"
@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask):
     upper_pronouns = ["A", "An", "The", "She", "He",
                       "It", "They", "My", "His", "Her", "Their"]

-    def __init__(self):
-        super().__init__()
-        self.data = self.__clean_data()
-
-    def __clean_data(self):
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return map(self._load_doc, self.dataset["test"])
+
+    def _load_doc(self, doc):
         # The HF implementation of `wsc273` is not `partial evaluation` friendly.
-        data = []
-        for doc in self.data["test"]:
-            doc["text"] = doc["text"].replace("  ", " ")
-            doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
-            doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
-            data.append(doc)
-        return {"test": data}
+        doc["text"] = doc["text"].replace("  ", " ")
+        doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
+        doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
+        return doc
@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask):
             return option.replace(pronoun, pronoun.lower())
         return option

-    def has_training_docs(self):
-        return False
-
-    def has_validation_docs(self):
-        return False
-
-    def has_test_docs(self):
-        return True
-
     def fewshot_examples(self, k, rnd):
         # NOTE: `super().fewshot_examples` samples from training docs which are
         # not available for this test-set-only dataset.
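Note: the `fewshot_examples` override above (its body is elided from this diff) exists because `wsc273` ships only a test split, so the default of sampling few-shot examples from training docs has nothing to draw from. A plausible shape for such an override, as a self-contained sketch in which the class, the `_fewshot_docs` cache attribute, and the stub docs are all assumptions:

import random


class _Wsc273Sketch:
    _fewshot_docs = None

    def test_docs(self):
        # Stand-in for the cleaned docs yielded by `_load_doc` above.
        return iter([{"text": "doc1"}, {"text": "doc2"}, {"text": "doc3"}])

    def fewshot_examples(self, k, rnd):
        # Test-set-only dataset: sample few-shot examples from the test
        # split rather than the nonexistent training split.
        if self._fewshot_docs is None:
            self._fewshot_docs = list(self.test_docs())
        return rnd.sample(self._fewshot_docs, k)


print(_Wsc273Sketch().fewshot_examples(2, random.Random(1234)))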
......
@@ -21,8 +21,7 @@ setuptools.setup(
     python_requires='>=3.6',
     install_requires=[
         "black",
-        "best_download==0.0.9",
-        "datasets==1.15.1",
+        "datasets==2.0.0",
         "click>=7.1",
         "scikit-learn>=0.24.1",
         "torch>=1.7",
@@ -43,6 +42,7 @@ setuptools.setup(
         "openai==0.6.4",
         "jieba==0.42.1",
         "nagisa==0.2.7",
+        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
     ],
     dependency_links=[
         "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
......
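Note on the packaging change: modern pip ignores `dependency_links`, so the BLEURT requirement only resolves once it appears in `install_requires` as a PEP 508 direct reference, as added above. The equivalent manual install is pip install "bleurt @ https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip", and the `datasets==2.0.0` pin matches the refactor's reliance on the HF loading APIs.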