Commit 7c9da714 authored by Jonathan Tow, committed by Jon Tow

Refactor `Task` download

parent 7064d6b9
......@@ -7,14 +7,12 @@ modified from Chinese high school English listening comprehension test data.
Homepage: https://github.com/Nealcly/MuTual
"""
import json
import zipfile
import shutil
import numpy as np
from pathlib import Path
import inspect
import lm_eval.datasets.mutual.mutual
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from best_download import download_file
_CITATION = """
......@@ -30,29 +28,10 @@ _CITATION = """
class MuTualBase(Task):
VERSION = 1
BASE_PATH = Path("data/mutual")
DATASET_PATH = inspect.getfile(lm_eval.datasets.mutual.mutual)
DATASET_NAME = None
CHOICES = ['A', 'B', 'C', 'D']
def __init__(self):
super().__init__()
def download(self):
if self.BASE_PATH.exists():
return
Path.mkdir(self.BASE_PATH, parents=True)
master_zip = Path("data/master.zip")
download_file(
"https://github.com/Nealcly/MuTual/archive/master.zip",
local_file=str(master_zip),
expected_checksum="bb325cf6c672f0f02699993a37138b0fa0af6fcfc77ec81dfbe46add4d7b29f9")
with zipfile.ZipFile(master_zip, 'r') as zip:
zip.extractall("data")
Path("data/MuTual-master/data").rename(str(self.BASE_PATH))
# Remove left over files and directories.
master_zip.unlink()
shutil.rmtree("data/MuTual-master")
def has_training_docs(self):
return True
......@@ -62,18 +41,11 @@ class MuTualBase(Task):
def has_test_docs(self):
return False
def _load_docs(self, path):
for file in sorted(path.iterdir()):
if file.suffix != ".txt":
continue
with open(file, 'r', encoding='utf-8') as f:
yield json.load(f)
def training_docs(self):
return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "train")
return self.dataset["train"]
def validation_docs(self):
return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "dev")
return self.dataset["validation"]
def test_docs(self):
return NotImplemented
......@@ -134,8 +106,8 @@ class MuTualBase(Task):
class MuTual(MuTualBase):
DATASET_NAME = Path("mutual")
DATASET_NAME = "mutual"
class MuTualPlus(MuTualBase):
DATASET_NAME = Path("mutual_plus")
DATASET_NAME = "mutual_plus"
......@@ -15,8 +15,7 @@ not even bother with the train set.
Homepage: https://ai.google.com/research/NaturalQuestions
"""
import random
from . common import HFTask
from lm_eval.base import Task
from itertools import islice
......@@ -30,7 +29,7 @@ _CITATION = """
"""
class NaturalQs(HFTask):
class NaturalQs(Task):
VERSION = 0
DATASET_PATH = "natural_questions"
DATASET_NAME = None
......@@ -47,7 +46,12 @@ class NaturalQs(HFTask):
def training_docs(self):
# Cache training for faster few-shot.
# Data is too large to fit in memory.
return self.data["train"]
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def fewshot_examples(self, k, rnd):
# Data is too large to fit in memory. We just sample from the first bit.
......
......@@ -15,7 +15,6 @@ based algorithm and a word co-occurrence algorithm.
Homepage: https://allenai.org/data/open-book-qa
"""
from lm_eval.base import MultipleChoiceTask
from .common import HFTask
_CITATION = """
......@@ -28,7 +27,7 @@ _CITATION = """
"""
class OpenBookQA(HFTask, MultipleChoiceTask):
class OpenBookQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "openbookqa"
DATASET_NAME = "main"
......@@ -42,6 +41,17 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return map(self._convert_standard, self._training_docs)
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
def _convert_standard(self, doc):
out_doc = {
"id": doc["id"],
......
......@@ -10,15 +10,9 @@ math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
import os
import lm_dataformat
import abc
import numpy as np
from lm_eval.base import rf, PerplexityTask
from ..metrics import mean, matthews_corrcoef, f1_score
from ..utils import general_detokenize
from best_download import download_file
import inspect
import lm_eval.datasets.pile.pile
from lm_eval.base import PerplexityTask
_CITATION = """
......@@ -31,32 +25,10 @@ _CITATION = """
"""
class PilePerplexityTask(PerplexityTask, abc.ABC):
class PilePerplexityTask(PerplexityTask):
VERSION = 1
PILE_SET_NAME = None
VAL_PATH = 'data/pile/val.jsonl.zst'
TEST_PATH = 'data/pile/test.jsonl.zst'
def download(self):
# TODO: separate pile val/test out by component so we don't have to scan the entire file once per set
if not os.path.exists("data/pile/test.jsonl.zst"):
# todo use new best_download fallback api
os.makedirs("data/pile/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/pile/val.jsonl.zst", local_file=self.VAL_PATH, expected_checksum="264c875d8bbd355d8daa9d032b75fd8fb91606218bb84dd1155b203fcd5fab92")
download_file("http://eaidata.bmk.sh/data/pile/test.jsonl.zst", local_file=self.TEST_PATH, expected_checksum="0bb28c52d0b5596d389bf179ce2d43bf7f7ffae76b0d2d20b180c97f62e0975e")
def validation_docs(self):
rdr = lm_dataformat.Reader(self.VAL_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
def test_docs(self):
rdr = lm_dataformat.Reader(self.TEST_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
DATASET_PATH = inspect.getfile(lm_eval.datasets.pile.pile)
DATASET_NAME = None
def has_validation_docs(self):
return True
......@@ -64,90 +36,98 @@ class PilePerplexityTask(PerplexityTask, abc.ABC):
def has_test_docs(self):
return True
def validation_docs(self):
for doc in self.dataset["validation"]:
yield doc["text"]
def test_docs(self):
for doc in self.dataset["test"]:
yield doc["text"]
class PileArxiv(PilePerplexityTask):
PILE_SET_NAME = "ArXiv"
DATASET_NAME = "pile_arxiv"
class PileBooks3(PilePerplexityTask):
PILE_SET_NAME = "Books3"
DATASET_NAME = "pile_books3"
class PileBookCorpus2(PilePerplexityTask):
PILE_SET_NAME = "BookCorpus2"
DATASET_NAME = "pile_bookcorpus2"
class PileDmMathematics(PilePerplexityTask):
PILE_SET_NAME = "DM Mathematics"
DATASET_NAME = "pile_dm-mathematics"
class PileEnron(PilePerplexityTask):
PILE_SET_NAME = "Enron Emails"
DATASET_NAME = "pile_enron"
class PileEuroparl(PilePerplexityTask):
PILE_SET_NAME = "EuroParl"
DATASET_NAME = "pile_europarl"
class PileFreeLaw(PilePerplexityTask):
PILE_SET_NAME = "FreeLaw"
DATASET_NAME = "pile_freelaw"
class PileGithub(PilePerplexityTask):
PILE_SET_NAME = "Github"
DATASET_NAME = "pile_github"
class PileGutenberg(PilePerplexityTask):
PILE_SET_NAME = "Gutenberg (PG-19)"
DATASET_NAME = "pile_gutenberg"
class PileHackernews(PilePerplexityTask):
PILE_SET_NAME = "HackerNews"
DATASET_NAME = "pile_hackernews"
class PileNIHExporter(PilePerplexityTask):
PILE_SET_NAME = "NIH ExPorter"
DATASET_NAME = "pile_nih-exporter"
class PileOpenSubtitles(PilePerplexityTask):
PILE_SET_NAME = "OpenSubtitles"
DATASET_NAME = "pile_opensubtitles"
class PileOpenWebText2(PilePerplexityTask):
PILE_SET_NAME = "OpenWebText2"
DATASET_NAME = "pile_openwebtext2"
class PilePhilPapers(PilePerplexityTask):
PILE_SET_NAME = "PhilPapers"
DATASET_NAME = "pile_philpapers"
class PilePileCc(PilePerplexityTask):
PILE_SET_NAME = "Pile-CC"
DATASET_NAME = "pile_pile-cc"
class PilePubmedAbstracts(PilePerplexityTask):
PILE_SET_NAME = "PubMed Abstracts"
DATASET_NAME = "pile_pubmed-abstracts"
class PilePubmedCentral(PilePerplexityTask):
PILE_SET_NAME = "PubMed Central"
DATASET_NAME = "pile_pubmed-central"
class PileStackExchange(PilePerplexityTask):
PILE_SET_NAME = "StackExchange"
DATASET_NAME = "pile_stackexchange"
class PileUspto(PilePerplexityTask):
PILE_SET_NAME = "USPTO Backgrounds"
DATASET_NAME = "pile_upsto"
class PileUbuntuIrc(PilePerplexityTask):
PILE_SET_NAME = "Ubuntu IRC"
DATASET_NAME = "pile_ubuntu-irc"
class PileWikipedia(PilePerplexityTask):
PILE_SET_NAME = "Wikipedia (en)"
DATASET_NAME = "pile_wikipedia"
class PileYoutubeSubtitles(PilePerplexityTask):
PILE_SET_NAME = "YoutubeSubtitles"
DATASET_NAME = "pile_youtubesubtitles"
......@@ -9,10 +9,7 @@ actually learning about the world?
Homepage: https://yonatanbisk.com/piqa/
"""
import numpy as np
from lm_eval.base import MultipleChoiceTask, rf
from ..metrics import mean
from . common import HFTask
from lm_eval.base import MultipleChoiceTask
_CITATION = """
......@@ -29,7 +26,7 @@ _CITATION = """
"""
class PiQA(HFTask, MultipleChoiceTask):
class PiQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "piqa"
DATASET_NAME = None
......@@ -43,6 +40,14 @@ class PiQA(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return map(self._convert_standard, self._training_docs)
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
def _convert_standard(self, doc):
out_doc = {
"goal": doc["goal"],
......
......@@ -15,7 +15,6 @@ have been trained on data not specifically collected to succeed on PROST."
Homepage: https://github.com/nala-cub/prost
"""
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """
......@@ -36,7 +35,7 @@ _CITATION = """
"""
class PROST(HFTask, MultipleChoiceTask):
class PROST(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "corypaik/prost"
DATASET_NAME = None
......@@ -50,6 +49,9 @@ class PROST(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context(
......
......@@ -16,9 +16,8 @@ and (4) a yes/no/maybe answer which summarizes the conclusion.
Homepage: https://pubmedqa.github.io/
"""
import numpy as np
from .common import HFTask
from lm_eval.base import rf
from ..metrics import mean
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -32,7 +31,7 @@ _CITATION = """
"""
class Pubmed_QA(HFTask):
class Pubmed_QA(Task):
VERSION = 0
DATASET_PATH = "pubmed_qa"
DATASET_NAME = "pqa_labeled"
......@@ -49,7 +48,7 @@ class Pubmed_QA(HFTask):
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but its really just for testing
return self.data["train"]
return self.dataset["train"]
def doc_to_text(self, doc):
ctxs = "\n".join(doc["context"]["contexts"])
......
......@@ -13,9 +13,6 @@ and Entrance Exam.
Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php
"""
import os
import xml.etree.ElementTree as ET
from best_download import download_file
from lm_eval.base import MultipleChoiceTask
......@@ -31,35 +28,8 @@ _CITATION = """
class QA4MRE(MultipleChoiceTask):
VERSION = 0
YEAR = None
def download(self):
year = self.YEAR
lang = "EN"
base_path = (
"http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
"file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
)
# TODO: add side tasks?
variable_year_path = {
2011: '2011/Training_Data/Goldstandard/',
2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
2013: '2013/Main_Task/Training_Data/Goldstandard/'
}
sha256sums = {
2011 : "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
2012 : "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
2013 : "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
}
vpath = variable_year_path[year]
url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
if not os.path.exists("data/qa4mre"):
os.makedirs("data/qa4mre", exist_ok=True)
if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}"):
download_file(
url_path,
local_file=f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
expected_checksum=sha256sums[year],
)
DATASET_PATH = "qa4mre"
DATASET_NAME = None
def has_training_docs(self):
return False
......@@ -70,39 +40,31 @@ class QA4MRE(MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, question):
choices = [i.text for i in question.iter('answer')]
def test_docs(self):
# `qa4mre` only has train data so we use it for the test docs.
return map(self._convert_standard, self.dataset["train"])
def _convert_standard(self, doc):
choices = doc["answer_options"]["answer_str"]
out_doc = {
"query" : question.find('q_str').text,
"choices": choices,
"gold" : int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
"source": doc["document_str"].strip().replace("\'", "'"),
"query": doc["question_str"],
"choices": choices,
"gold": int(doc["correct_answer_id"]) - 1,
}
return out_doc
def load_docs(self, textfilename, tfds=False):
tree = ET.parse(textfilename)
root = tree.getroot()
# TODO: the source passage is sometimes much larger than the model context
# at the moment, it just gets left-truncated by LM automatically, and maybe that's good enough?
for reading_test in root.iter('reading-test'):
src = reading_test[0].text
src = src.strip().replace("\'", "'")
for qid, question in enumerate(reading_test.iter('q')):
out_doc = self._convert_standard(question)
out_doc['source'] = src
yield out_doc
def test_docs(self):
return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
class QA4MRE_2011(QA4MRE):
YEAR = 2011
DATASET_NAME = "2011.main.EN"
class QA4MRE_2012(QA4MRE):
YEAR = 2012
DATASET_NAME = "2012.main.EN"
class QA4MRE_2013(QA4MRE):
YEAR = 2013
DATASET_NAME = "2013.main.EN"
......@@ -11,13 +11,10 @@ provide supporting evidence to answers.
Homepage: https://allenai.org/data/qasper
"""
from collections import Counter
from math import exp
import random
import re
import string
from lm_eval.base import rf
from lm_eval.base import rf, Task
from lm_eval.metrics import f1_score, mean
from .common import HFTask
_CITATION = """
......@@ -104,11 +101,20 @@ def token_f1_score(prediction, ground_truth):
return f1
class QASPER(HFTask):
class QASPER(Task):
VERSION = 0
DATASET_PATH = "qasper"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def doc_to_text(self, doc):
return (
"TITLE: "
......@@ -130,11 +136,11 @@ class QASPER(HFTask):
return " " + answer
def training_docs(self):
for doc in self.data["train"]:
for doc in self.dataset["train"]:
yield from self.process_doc(doc)
def validation_docs(self):
for doc in self.data["train"]:
for doc in self.dataset["validation"]:
yield from self.process_doc(doc)
def process_doc(self, doc):
......
......@@ -10,10 +10,9 @@ a teacher who answers the questions by providing short excerpts (spans) from the
Homepage: https://quac.ai/
"""
import json
import os
import inspect
import lm_eval.datasets.quac.quac
from lm_eval.base import Task
from ..utils import sh
_CITATION = """
......@@ -28,18 +27,8 @@ _CITATION = """
class QuAC(Task):
VERSION = 0
def __init__(self):
super().__init__()
def download(self):
if not os.path.exists('data/quac'):
# TODO: convert to use best_download
sh("""
mkdir -p data/quac
wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -O data/quac/train_v0.2.json
wget https://s3.amazonaws.com/my89public/quac/val_v0.2.json -O data/quac/val_v0.2.json
""")
DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -51,28 +40,18 @@ class QuAC(Task):
return False
def training_docs(self):
myjson = json.load(open('data/quac/train_v0.2.json'))['data']
return self.load_doc(myjson)
return map(self._convert_standard, self.dataset["train"])
def validation_docs(self):
myjson = json.load(open('data/quac/val_v0.2.json'))['data']
return self.load_doc(myjson)
return map(self._convert_standard, self.dataset["validation"])
def test_docs(self):
raise NotImplementedError("QuAC has no test docs.")
def load_doc(self, myjson):
docs = []
for item in myjson:
title = item['title'] + ' - ' + item['section_title']
paragraph = item['paragraphs'][0]['context'].replace("CANNOTANSWER", "")
qas = item['paragraphs'][0]['qas']
qa_pairs = [(qa['question'], qa['answers'][0]['text']) for qa in qas]
for (question, answer) in qa_pairs:
doc = { 'title': title, 'paragraph': paragraph, 'question': question, 'answer': answer }
docs.append(doc)
return docs
def _convert_standard(self, doc):
doc["title"] = doc['title'] + ' - ' + doc['section_title']
return doc
def doc_to_text(self, doc):
return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '
......@@ -88,7 +67,7 @@ class QuAC(Task):
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
......
......@@ -12,9 +12,8 @@ Homepage: https://www.cs.cmu.edu/~glai1/data/race/
import collections
import datasets
import numpy as np
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -35,16 +34,14 @@ class each:
return list(map(self.f, other))
class RACE(HFTask):
VERSION = 0
class RACE(Task):
VERSION = 1
DATASET_PATH = "race"
DATASET_NAME = "high"
cache = {}
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
assert datasets.__version__ == "1.15.1", "RACE requires datasets==1.15.1!"
def has_training_docs(self):
return True
......
......@@ -7,7 +7,8 @@ multiple-choice analogy questions; 5 choices per question.
Homepage: https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art)
"""
import os
import inspect
import lm_eval.datasets.sat_analogies.sat_analogies
from lm_eval.base import MultipleChoiceTask
......@@ -25,20 +26,18 @@ _CITATION = """
"""
class SATAnalogies(MultipleChoiceTask):
VERSION = 0
NEEDS_MANUAL_DL = True
def __init__(self):
super().__init__()
DATASET_PATH = inspect.getfile(lm_eval.datasets.sat_analogies.sat_analogies)
DATASET_NAME = None
def download(self):
# We should be using a checksum here.
# The canonical sha256 hash is below:
# 9dece377d8d57253ef8c78370ff15de0bb1d9e90a82c815a67ba1e621e921bfc
if not os.path.exists('data/sat/SAT-package-V3.txt'):
raise NotImplementedError('SAT Analogies dataset is not provided. Follow instructions on https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art) to locate.')
def __init__(self, data_dir: str):
"""
SAT Analogy Questions is not publicly available. You must request the data
by emailing Peter Turney and then download it to a local directory path
which should be passed into the `data_dir` arg.
"""
super().__init__(data_dir=data_dir)
def has_training_docs(self):
return False
......@@ -51,38 +50,20 @@ class SATAnalogies(MultipleChoiceTask):
def training_docs(self):
return []
def test_docs(self):
return []
def validation_docs(self):
data = []
return map(self._convert_standard, self.dataset["validation"])
with open("data/sat/SAT-package-V3.txt", "r") as f:
record = []
for line in f:
line = line.strip()
if len(line) == 0 and record:
data.append(record)
record = []
elif len(line) > 0 and line[0] == '#':
continue
else:
record.append(line)
data.append(record)
for record in data:
source = record[-8]
query = record[-7]
choices = record[-6:-1]
answer_key = record[-1]
def test_docs(self):
return []
doc = {
'source': source,
'query': query.split(' ')[:2],
'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in choices],
'gold': ['a','b','c','d','e'].index(answer_key.strip()),
}
yield doc
def _convert_standard(self, doc):
return {
'source': doc['source'],
'query': doc['stem'].split(' ')[:2],
'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]],
'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()),
}
def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query'])
......@@ -9,11 +9,7 @@ with supporting evidence for the correct answer is provided.
Homepage: https://allenai.org/data/sciq
"""
import os
import json
import zipfile
from lm_eval.base import MultipleChoiceTask
from best_download import download_file
_CITATION = """
......@@ -28,17 +24,8 @@ _CITATION = """
class SciQ(MultipleChoiceTask):
VERSION = 0
# Multiple languages and multiple years
def download(self):
if not os.path.exists('data/sciq'):
os.makedirs('data/sciq', exist_ok=True)
download_file(
'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
local_file='data/sciq/SciQ.zip',
expected_checksum='7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
)
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/")
DATASET_PATH = "sciq"
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -49,36 +36,32 @@ class SciQ(MultipleChoiceTask):
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return map(self._convert_standard, self._training_docs)
def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"])
def test_docs(self):
return map(self._convert_standard, self.dataset["test"])
def _convert_standard(self, doc):
choices = [
doc["distractor1"],
doc["distractor2"],
doc["distractor1"],
doc["distractor2"],
doc["distractor3"],
doc["correct_answer"],
]
src = doc['support']
out_doc = {
"source" : src,
"query" : doc['question'],
"choices" : choices,
"gold" : 3,
"source": src,
"query": doc['question'],
"choices": choices,
"gold": 3,
}
return out_doc
def load_docs(self, textfilename):
with open(textfilename, 'r') as j:
docs = json.loads(j.read())
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
def validation_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
def test_docs(self):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
......@@ -15,9 +15,7 @@ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
"""
import datasets
from math import exp
from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from . common import HFTask
from lm_eval.base import rf, Task
from functools import partial
from packaging import version
......@@ -45,7 +43,7 @@ def _squad_agg(key, items):
return _squad_metric(predictions=predictions, references=references)[key]
class SQuAD2(HFTask):
class SQuAD2(Task):
VERSION = 1
DATASET_PATH = "squad_v2"
DATASET_NAME = None
......@@ -63,10 +61,10 @@ class SQuAD2(HFTask):
return False
def training_docs(self):
return self.data["train"]
return self.dataset["train"]
def validation_docs(self):
return self.data["validation"]
return self.dataset["validation"]
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
......
......@@ -8,8 +8,9 @@ to choose the correct ending to a four-sentence story.
Homepage: https://cs.rochester.edu/nlp/rocstories/
"""
import csv
from lm_eval.base import Task
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -34,11 +35,16 @@ _CITATION = """
class StoryCloze(Task):
VERSION = 0
NEEDS_MANUAL_DL = True
DATASET_PATH = "story_cloze"
DATASET_NAME = None
def download(self):
#TODO: replace with Eye link
pass
def __init__(self, data_dir: str):
"""
StoryCloze is not publicly available. You must download the data by
following the instructions at https://cs.rochester.edu/nlp/rocstories/ and pass the folder
path into the `data_dir` arg.
"""
super().__init__(data_dir=data_dir)
def has_training_docs(self):
return False
......@@ -52,40 +58,46 @@ class StoryCloze(Task):
def training_docs(self):
pass
def load_doc(self, filename):
with open(filename, newline='') as file:
filereader = csv.reader(file)
return list(filereader)
def validation_docs(self):
return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
return self.dataset["validation"]
def test_docs(self):
return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
return self.dataset["test"]
def doc_to_text(self, doc):
return ' '.join([*doc[1:5]])
return ' '.join([
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
])
def doc_to_target(self, doc):
return " " + doc[int(doc[-1]) - 4]
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
# `- 1` because the `answer_right_ending` index is 1-based.
return " " + clozes[doc["answer_right_ending"] - 1]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0]
for choice in clozes
]
return lls
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -93,23 +105,36 @@ class StoryCloze(Task):
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
gold = doc["answer_right_ending"] - 1
acc = 1. if np.argmax(results) == gold else 0.
return {
"acc": acc
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": True
}
class StoryCloze2016(StoryCloze):
DATASET_NAME = "2016"
class StoryCloze2018(StoryCloze):
DATASET_NAME = "2018"
......@@ -12,10 +12,9 @@ TODO: WSC requires free-form generation.
import numpy as np
import sklearn
import transformers.data.metrics.squad_metrics as squad_metrics
from . common import HFTask, yesno
from lm_eval.base import rf
from ..metrics import mean, acc_all, metric_max_over_ground_truths
from ..utils import general_detokenize
from lm_eval.base import rf, Task
from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
from lm_eval.utils import general_detokenize
_CITATION = """
......@@ -33,7 +32,7 @@ _CITATION = """
"""
class BoolQ(HFTask):
class BoolQ(Task):
VERSION = 1
DATASET_PATH = "super_glue"
DATASET_NAME = "boolq"
......@@ -47,6 +46,14 @@ class BoolQ(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
......@@ -81,7 +88,7 @@ class BoolQ(HFTask):
}
class CommitmentBank(HFTask):
class CommitmentBank(Task):
VERSION = 1
DATASET_PATH = "super_glue"
DATASET_NAME = "cb"
......@@ -95,6 +102,14 @@ class CommitmentBank(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
doc["premise"],
......@@ -148,7 +163,7 @@ class CommitmentBank(HFTask):
}
class Copa(HFTask):
class Copa(Task):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "copa"
......@@ -162,6 +177,14 @@ class Copa(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
# Drop the period
connector = {
......@@ -208,7 +231,7 @@ class Copa(HFTask):
return choice[0].lower() + choice[1:]
class MultiRC(HFTask):
class MultiRC(Task):
VERSION = 1
DATASET_PATH = "super_glue"
DATASET_NAME = "multirc"
......@@ -222,6 +245,14 @@ class MultiRC(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
......@@ -260,7 +291,7 @@ class MultiRC(HFTask):
}
class ReCoRD(HFTask):
class ReCoRD(Task):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "record"
......@@ -279,13 +310,13 @@ class ReCoRD(HFTask):
# Each doc consists of multiple answer candidates, each of which is scored yes/no.
if self._training_docs is None:
self._training_docs = []
for doc in self.data["train"]:
for doc in self.dataset["train"]:
self._training_docs.append(self._process_doc(doc))
return self._training_docs
def validation_docs(self):
# See: training_docs
for doc in self.data["validation"]:
for doc in self.dataset["validation"]:
yield self._process_doc(doc)
@classmethod
......@@ -349,7 +380,7 @@ class ReCoRD(HFTask):
}
class WordsInContext(HFTask):
class WordsInContext(Task):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "wic"
......@@ -363,6 +394,14 @@ class WordsInContext(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nAnswer:".format(
......@@ -401,7 +440,7 @@ class WordsInContext(HFTask):
}
class SGWinogradSchemaChallenge(HFTask):
class SGWinogradSchemaChallenge(Task):
VERSION = 0
# Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
# binary version of the task.
......@@ -423,11 +462,14 @@ class SGWinogradSchemaChallenge(HFTask):
# GPT-3 Paper's format only uses positive examples for fewshot "training"
self._training_docs = [
doc for doc in
self.data["train"]
self.dataset["train"]
if doc["label"]
]
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
raw_passage = doc["text"]
# NOTE: HuggingFace span indices are word-based not character-based.
......
......@@ -9,13 +9,10 @@ high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import os
import json
import jsonlines
import inspect
import lm_eval.datasets.triviaqa.triviaqa
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
from best_download import download_file
from lm_eval.metrics import mean
_CITATION = """
......@@ -33,14 +30,8 @@ _CITATION = """
class TriviaQA(Task):
VERSION = 0
def download(self):
if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
os.makedirs("data/triviaqa/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", local_file="data/triviaqa/triviaqa-unfiltered.tar.gz", expected_checksum="adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
sh("""
cd data/triviaqa/
tar -xf triviaqa-unfiltered.tar.gz
""")
DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -52,19 +43,19 @@ class TriviaQA(Task):
return False
def training_docs(self):
return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
return self.dataset['train']
def validation_docs(self):
return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
return self.dataset['validation']
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Question: {doc['Question']}\nAnswer:"
return f"Question: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + doc['Answer']['Value']
return " " + doc['answer']['value']
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
......@@ -74,12 +65,11 @@ class TriviaQA(Task):
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
ret = []
for alias in self._remove_prefixes(doc['Answer']['Aliases']):
for alias in self._remove_prefixes(doc['answer']['aliases']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
......
......@@ -19,16 +19,14 @@ we could try this?
Homepage: https://github.com/sylinrl/TruthfulQA
"""
import csv
import json
import inspect
import numpy as np
import sacrebleu
import datasets
import lm_eval.datasets.truthfulqa.truthfulqa
from rouge_score import rouge_scorer, scoring
from lm_eval.base import rf, Task
from pathlib import Path
from best_download import download_file
from ..metrics import mean
from datasets import load_metric
from lm_eval.metrics import mean
_CITATION = """
......@@ -62,15 +60,8 @@ QA_PROMPT = (
class TruthfulQAMultipleChoice(Task):
VERSION = 1
DATASET_PATH = Path('data/truthfulqa/mc')
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json"
checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954"
download_file(mc_url, local_file=str(self.DATASET_PATH / "mc_task.json"), expected_checksum=checksum)
DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
DATASET_NAME = "multiple_choice"
def has_training_docs(self):
return False
......@@ -85,8 +76,7 @@ class TruthfulQAMultipleChoice(Task):
raise NotImplementedError()
def validation_docs(self):
with open(self.DATASET_PATH / "mc_task.json") as f:
return json.load(f)
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
......@@ -121,7 +111,7 @@ class TruthfulQAMultipleChoice(Task):
return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]
# MC1 and MC2 targets are not always the same set of strings so we collect
# likelihoods separately for simpler processing.
return get_lls(doc['mc1_targets']) + get_lls(doc['mc2_targets'])
return get_lls(doc['mc1_targets']["choices"]) + get_lls(doc['mc2_targets']["choices"])
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -139,14 +129,14 @@ class TruthfulQAMultipleChoice(Task):
def mc2(lls):
# Split on the first `0` as everything before it is true (`1`).
split_idx = list(doc['mc2_targets'].values()).index(0)
split_idx = list(doc['mc2_targets']["labels"]).index(0)
# Compute the normalized probability mass for the correct answer.
ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false))
return sum(p_true)
split_idx = len(doc['mc1_targets'])
split_idx = len(doc['mc1_targets']["choices"])
mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]
return {
"mc1": mc1(mc1_lls),
......@@ -168,19 +158,12 @@ class TruthfulQAMultipleChoice(Task):
class TruthfulQAGeneration(Task):
VERSION = 1
DATASET_PATH = Path('data/truthfulqa/generation')
DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
DATASET_NAME = "generation"
def __init__(self):
super().__init__()
self.bleurt = load_metric("bleurt", cache_dir="lm_cache")
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv"
checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2"
download_file(url, local_file=str(self.DATASET_PATH / "TruthfulQA.csv"), expected_checksum=checksum)
self.bleurt = datasets.load_metric("bleurt")
def has_training_docs(self):
return False
......@@ -194,36 +177,29 @@ class TruthfulQAGeneration(Task):
def training_docs(self):
raise NotImplementedError()
def _split_multi_answer(self, answers, sep=';'):
answers = answers.strip().split(sep)
split_answers = []
def _format_answers(self, answers):
formatted_answers = []
for answer in answers:
answer = answer.strip()
if len(answer):
# Add a period after all answers.
if answer[-1] != '.':
split_answers.append(answer + '.')
formatted_answers.append(answer + '.')
else:
split_answers.append(answer)
return split_answers
formatted_answers.append(answer)
return formatted_answers
def validation_docs(self):
with open(self.DATASET_PATH / "TruthfulQA.csv", newline='') as csvfile:
doc_reader = csv.DictReader(csvfile)
for doc in doc_reader:
# Ensure that references exist.
if not doc['Correct Answers'] or not doc['Incorrect Answers']:
continue
correct_answers = self._split_multi_answer(doc['Correct Answers'])
if "I have no comment." not in correct_answers:
correct_answers.append("I have no comment.")
incorrect_answers = self._split_multi_answer(doc['Incorrect Answers'])
doc = {
'question': doc['Question'].strip(),
'correct_answers': correct_answers,
'incorrect_answers': incorrect_answers
}
yield doc
for doc in self.dataset["validation"]:
incorrect_answers = self._format_answers(doc['incorrect_answers'])
correct_answers = self._format_answers(doc['correct_answers'])
if "I have no comment." not in correct_answers:
correct_answers.append("I have no comment.")
yield {
'question': doc['question'].strip(),
'correct_answers': correct_answers,
'incorrect_answers': incorrect_answers
}
def test_docs(self):
raise NotImplementedError()
......
......@@ -8,11 +8,8 @@ addition, or deletion of characters, and asking it to recover the original word.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
import gzip
import json
import shutil
from pathlib import Path
from best_download import download_file
import inspect
import lm_eval.datasets.unscramble.unscramble
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
......@@ -32,30 +29,10 @@ _CITATION = """
"""
def extract_gzip(gz, to):
with gzip.open(gz, 'rb') as fin:
with open(to, 'wb') as fout:
shutil.copyfileobj(fin, fout)
class WordUnscrambleTask(Task):
VERSION = 0
BASE_PATH = Path("data/unscramble")
FILENAME = None
CHECKSUM = None # SHA256 Checksum.
def __init__(self):
super().__init__()
def download(self):
if not self.BASE_PATH.exists():
Path.mkdir(self.BASE_PATH, parents=True)
file = self.BASE_PATH / self.FILENAME
if not file.exists():
rawfile = file.parent / (file.name + ".gz")
base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
download_file(f"{base_url}/{self.FILENAME}.gz", local_file=str(rawfile), expected_checksum=self.CHECKSUM)
extract_gzip(gz=rawfile, to=file)
DATASET_PATH = inspect.getfile(lm_eval.datasets.unscramble.unscramble)
DATASET_NAME = None
def has_training_docs(self):
return False
......@@ -67,8 +44,7 @@ class WordUnscrambleTask(Task):
return False
def validation_docs(self):
file = self.BASE_PATH / self.FILENAME
return (json.loads(line) for line in open(file).read().splitlines())
return self.dataset["validation"]
def doc_to_text(self, doc):
return doc["context"]
......@@ -99,25 +75,20 @@ class WordUnscrambleTask(Task):
class Anagrams1(WordUnscrambleTask):
FILENAME = "mid_word_1_anagrams.jsonl"
CHECKSUM = "6768a86896083199de4815d4964cb2f6f1046476cfd80c2a562784f182905979"
DATASET_NAME = "mid_word_1_anagrams"
class Anagrams2(WordUnscrambleTask):
FILENAME = "mid_word_2_anagrams.jsonl"
CHECKSUM = "c3d839d09a7954b78a27cd2cd75d4ed0488656c56ef4dbd741a005343826cb01"
DATASET_NAME = "mid_word_2_anagrams"
class CycleLetters(WordUnscrambleTask):
FILENAME = "cycle_letters_in_word.jsonl"
CHECKSUM = "1689c9002bb8c5988bf5f05e977c9db92f57932c1b5a38998c29ac0dd71e1d42"
DATASET_NAME = "cycle_letters_in_word"
class RandomInsertion(WordUnscrambleTask):
FILENAME = "random_insertion_in_word.jsonl"
CHECKSUM = "72e65d83da53d15752ee0c47379509de149ddbad32d61184e5991df29616b78a"
DATASET_NAME = "random_insertion_in_word"
class ReversedWords(WordUnscrambleTask):
FILENAME = "reversed_words.jsonl"
CHECKSUM = "133a08f875cd6c1ef8608a3233571a773881cc27b1c707de738cc6543439332a"
DATASET_NAME = "reversed_words"
......@@ -9,9 +9,8 @@ The questions are popular ones asked on the web (at least in 2013).
Homepage: https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a
"""
from . common import HFTask
from lm_eval.base import rf
from ..metrics import mean
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -32,7 +31,7 @@ _CITATION = """
"""
class WebQs(HFTask):
class WebQs(Task):
VERSION = 0
DATASET_PATH = "web_questions"
DATASET_NAME = None
......@@ -46,6 +45,14 @@ class WebQs(HFTask):
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
......