".circleci/unittest/vscode:/vscode.git/clone" did not exist on "70bb4920343e733921f036bc813e0f100578b86c"
Unverified commit 6caa0afd authored by Leo Gao, committed by GitHub

Merge pull request #300 from jon-tow/hf-dataset-refactor

Refactor `Task` downloading to use `HuggingFace.datasets`
Parents: 7064d6b9 9434722c
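Editorial note: at a high level, every task touched by this PR drops its hand-rolled download() logic (best_download, zip/XML/CSV parsing, hard-coded data/ paths) and instead declares a DATASET_PATH / DATASET_NAME pair that the base Task presumably resolves through `datasets.load_dataset`, exposing the loaded splits as `self.dataset`. A minimal sketch of the resulting shape, distilled from the hunks below (illustrative only; the field names in `_process_doc` and the dataset id are placeholders, and the real plumbing lives in lm_eval.base, which is not shown in this diff):

# Illustrative sketch only -- mirrors the pattern used by the tasks below,
# not the actual lm_eval.base implementation.
from lm_eval.base import MultipleChoiceTask

class ExampleTask(MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "some_hub_dataset"   # hypothetical HF Hub dataset id
    DATASET_NAME = None                 # or a specific config name

    def has_training_docs(self):
        return True

    def training_docs(self):
        # Splits come from `self.dataset`, which the base class presumably fills
        # via `datasets.load_dataset(self.DATASET_PATH, self.DATASET_NAME, ...)`.
        if self._training_docs is None:
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def _process_doc(self, doc):
        # Map a raw HF row into the harness's standard fields (placeholder keys).
        return {"query": doc["question"], "choices": doc["options"], "gold": doc["label"]}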
@@ -10,7 +10,6 @@ Homepage: https://math-qa.github.io/math-QA/
 """
 import re
 from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
 _CITATION = """
@@ -25,7 +24,7 @@ _CITATION = """
 """
-class MathQA(HFTask, MultipleChoiceTask):
+class MathQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "math_qa"
     DATASET_NAME = None
@@ -39,13 +38,23 @@ class MathQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+    def _process_doc(self, doc):
         answer_idx = ['a', 'b', 'c', 'd', 'e'].index(doc['correct'])
         choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", doc['options'])]
         out_doc = {
-            "query": "Question: " + doc['Problem'] +"\nAnswer:",
+            "query": "Question: " + doc['Problem'] + "\nAnswer:",
             "choices": choices,
             "gold": answer_idx,
         }
...
@@ -20,9 +20,8 @@ of a question's options. See section 4 of the paper for details.
 Homepage: https://leaderboard.allenai.org/mctaco/submissions/public
 """
 import numpy as np
-from lm_eval.base import rf
 from collections import defaultdict
-from . common import HFTask
+from lm_eval.base import rf, Task
 _CITATION = """
@@ -35,7 +34,7 @@ _CITATION = """
 """
-class MCTACO(HFTask):
+class MCTACO(Task):
     VERSION = 0
     DATASET_PATH = "mc_taco"
     DATASET_NAME = None
@@ -49,6 +48,12 @@ class MCTACO(HFTask):
     def has_test_docs(self):
         return True
+    def validation_docs(self):
+        return self.dataset["validation"]
+    def test_docs(self):
+        return self.dataset["test"]
     def doc_to_text(self, doc):
         return f"{doc['sentence']}\nQuestion: {doc['question']}\n"\
             f"Answer: {doc['answer']}\nPlausible:"
...
@@ -7,14 +7,11 @@ modified from Chinese high school English listening comprehension test data.
 Homepage: https://github.com/Nealcly/MuTual
 """
-import json
-import zipfile
-import shutil
 import numpy as np
-from pathlib import Path
+import inspect
+import lm_eval.datasets.mutual.mutual
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from best_download import download_file
 _CITATION = """
@@ -30,29 +27,10 @@ _CITATION = """
 class MuTualBase(Task):
     VERSION = 1
-    BASE_PATH = Path("data/mutual")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.mutual.mutual)
     DATASET_NAME = None
     CHOICES = ['A', 'B', 'C', 'D']
-    def __init__(self):
-        super().__init__()
-    def download(self):
-        if self.BASE_PATH.exists():
-            return
-        Path.mkdir(self.BASE_PATH, parents=True)
-        master_zip = Path("data/master.zip")
-        download_file(
-            "https://github.com/Nealcly/MuTual/archive/master.zip",
-            local_file=str(master_zip),
-            expected_checksum="bb325cf6c672f0f02699993a37138b0fa0af6fcfc77ec81dfbe46add4d7b29f9")
-        with zipfile.ZipFile(master_zip, 'r') as zip:
-            zip.extractall("data")
-        Path("data/MuTual-master/data").rename(str(self.BASE_PATH))
-        # Remove left over files and directories.
-        master_zip.unlink()
-        shutil.rmtree("data/MuTual-master")
     def has_training_docs(self):
         return True
@@ -62,18 +40,11 @@ class MuTualBase(Task):
     def has_test_docs(self):
         return False
-    def _load_docs(self, path):
-        for file in sorted(path.iterdir()):
-            if file.suffix != ".txt":
-                continue
-            with open(file, 'r', encoding='utf-8') as f:
-                yield json.load(f)
     def training_docs(self):
-        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "train")
+        return self.dataset["train"]
     def validation_docs(self):
-        return self._load_docs(self.BASE_PATH / self.DATASET_NAME / "dev")
+        return self.dataset["validation"]
     def test_docs(self):
         return NotImplemented
@@ -134,8 +105,8 @@ class MuTualBase(Task):
 class MuTual(MuTualBase):
-    DATASET_NAME = Path("mutual")
+    DATASET_NAME = "mutual"
 class MuTualPlus(MuTualBase):
-    DATASET_NAME = Path("mutual_plus")
+    DATASET_NAME = "mutual_plus"
@@ -15,8 +15,7 @@ not even bother with the train set.
 Homepage: https://ai.google.com/research/NaturalQuestions
 """
-import random
-from . common import HFTask
+from lm_eval.base import Task
 from itertools import islice
@@ -30,7 +29,7 @@ _CITATION = """
 """
-class NaturalQs(HFTask):
+class NaturalQs(Task):
     VERSION = 0
     DATASET_PATH = "natural_questions"
     DATASET_NAME = None
@@ -47,7 +46,12 @@ class NaturalQs(HFTask):
     def training_docs(self):
         # Cache training for faster few-shot.
         # Data is too large to fit in memory.
-        return self.data["train"]
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def fewshot_examples(self, k, rnd):
         # Data is too large to fit in memory. We just sample from the first bit.
...
@@ -15,7 +15,6 @@ based algorithm and a word co-occurrence algorithm.
 Homepage: https://allenai.org/data/open-book-qa
 """
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask
 _CITATION = """
@@ -28,7 +27,7 @@ _CITATION = """
 """
-class OpenBookQA(HFTask, MultipleChoiceTask):
+class OpenBookQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "openbookqa"
     DATASET_NAME = "main"
@@ -42,7 +41,18 @@ class OpenBookQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+    def _process_doc(self, doc):
         out_doc = {
             "id": doc["id"],
             "query": doc["question_stem"],
...
@@ -10,15 +10,9 @@ math, computer science, and philosophy papers.
 Homepage: https://pile.eleuther.ai/
 """
-import os
-import lm_dataformat
-import abc
-import numpy as np
-from lm_eval.base import rf, PerplexityTask
-from ..metrics import mean, matthews_corrcoef, f1_score
-from ..utils import general_detokenize
-from best_download import download_file
+import inspect
+import lm_eval.datasets.pile.pile
+from lm_eval.base import PerplexityTask
 _CITATION = """
@@ -31,32 +25,10 @@ _CITATION = """
 """
-class PilePerplexityTask(PerplexityTask, abc.ABC):
+class PilePerplexityTask(PerplexityTask):
     VERSION = 1
-    PILE_SET_NAME = None
-    VAL_PATH = 'data/pile/val.jsonl.zst'
-    TEST_PATH = 'data/pile/test.jsonl.zst'
-    def download(self):
-        # TODO: separate pile val/test out by component so we don't have to scan the entire file once per set
-        if not os.path.exists("data/pile/test.jsonl.zst"):
-            # todo use new best_download fallback api
-            os.makedirs("data/pile/", exist_ok=True)
-            download_file("http://eaidata.bmk.sh/data/pile/val.jsonl.zst", local_file=self.VAL_PATH, expected_checksum="264c875d8bbd355d8daa9d032b75fd8fb91606218bb84dd1155b203fcd5fab92")
-            download_file("http://eaidata.bmk.sh/data/pile/test.jsonl.zst", local_file=self.TEST_PATH, expected_checksum="0bb28c52d0b5596d389bf179ce2d43bf7f7ffae76b0d2d20b180c97f62e0975e")
-    def validation_docs(self):
-        rdr = lm_dataformat.Reader(self.VAL_PATH)
-        for doc, metadata in rdr.stream_data(get_meta=True):
-            if metadata["pile_set_name"] == self.PILE_SET_NAME:
-                yield doc
-    def test_docs(self):
-        rdr = lm_dataformat.Reader(self.TEST_PATH)
-        for doc, metadata in rdr.stream_data(get_meta=True):
-            if metadata["pile_set_name"] == self.PILE_SET_NAME:
-                yield doc
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.pile.pile)
+    DATASET_NAME = None
     def has_validation_docs(self):
         return True
@@ -64,90 +36,98 @@ class PilePerplexityTask(PerplexityTask, abc.ABC):
     def has_test_docs(self):
         return True
+    def validation_docs(self):
+        for doc in self.dataset["validation"]:
+            yield doc["text"]
+    def test_docs(self):
+        for doc in self.dataset["test"]:
+            yield doc["text"]
 class PileArxiv(PilePerplexityTask):
-    PILE_SET_NAME = "ArXiv"
+    DATASET_NAME = "pile_arxiv"
 class PileBooks3(PilePerplexityTask):
-    PILE_SET_NAME = "Books3"
+    DATASET_NAME = "pile_books3"
 class PileBookCorpus2(PilePerplexityTask):
-    PILE_SET_NAME = "BookCorpus2"
+    DATASET_NAME = "pile_bookcorpus2"
 class PileDmMathematics(PilePerplexityTask):
-    PILE_SET_NAME = "DM Mathematics"
+    DATASET_NAME = "pile_dm-mathematics"
 class PileEnron(PilePerplexityTask):
-    PILE_SET_NAME = "Enron Emails"
+    DATASET_NAME = "pile_enron"
 class PileEuroparl(PilePerplexityTask):
-    PILE_SET_NAME = "EuroParl"
+    DATASET_NAME = "pile_europarl"
 class PileFreeLaw(PilePerplexityTask):
-    PILE_SET_NAME = "FreeLaw"
+    DATASET_NAME = "pile_freelaw"
 class PileGithub(PilePerplexityTask):
-    PILE_SET_NAME = "Github"
+    DATASET_NAME = "pile_github"
 class PileGutenberg(PilePerplexityTask):
-    PILE_SET_NAME = "Gutenberg (PG-19)"
+    DATASET_NAME = "pile_gutenberg"
 class PileHackernews(PilePerplexityTask):
-    PILE_SET_NAME = "HackerNews"
+    DATASET_NAME = "pile_hackernews"
 class PileNIHExporter(PilePerplexityTask):
-    PILE_SET_NAME = "NIH ExPorter"
+    DATASET_NAME = "pile_nih-exporter"
 class PileOpenSubtitles(PilePerplexityTask):
-    PILE_SET_NAME = "OpenSubtitles"
+    DATASET_NAME = "pile_opensubtitles"
 class PileOpenWebText2(PilePerplexityTask):
-    PILE_SET_NAME = "OpenWebText2"
+    DATASET_NAME = "pile_openwebtext2"
 class PilePhilPapers(PilePerplexityTask):
-    PILE_SET_NAME = "PhilPapers"
+    DATASET_NAME = "pile_philpapers"
 class PilePileCc(PilePerplexityTask):
-    PILE_SET_NAME = "Pile-CC"
+    DATASET_NAME = "pile_pile-cc"
 class PilePubmedAbstracts(PilePerplexityTask):
-    PILE_SET_NAME = "PubMed Abstracts"
+    DATASET_NAME = "pile_pubmed-abstracts"
 class PilePubmedCentral(PilePerplexityTask):
-    PILE_SET_NAME = "PubMed Central"
+    DATASET_NAME = "pile_pubmed-central"
 class PileStackExchange(PilePerplexityTask):
-    PILE_SET_NAME = "StackExchange"
+    DATASET_NAME = "pile_stackexchange"
 class PileUspto(PilePerplexityTask):
-    PILE_SET_NAME = "USPTO Backgrounds"
+    DATASET_NAME = "pile_upsto"
 class PileUbuntuIrc(PilePerplexityTask):
-    PILE_SET_NAME = "Ubuntu IRC"
+    DATASET_NAME = "pile_ubuntu-irc"
 class PileWikipedia(PilePerplexityTask):
-    PILE_SET_NAME = "Wikipedia (en)"
+    DATASET_NAME = "pile_wikipedia"
 class PileYoutubeSubtitles(PilePerplexityTask):
-    PILE_SET_NAME = "YoutubeSubtitles"
+    DATASET_NAME = "pile_youtubesubtitles"
@@ -9,10 +9,7 @@ actually learning about the world?
 Homepage: https://yonatanbisk.com/piqa/
 """
-import numpy as np
-from lm_eval.base import MultipleChoiceTask, rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import MultipleChoiceTask
 _CITATION = """
@@ -29,7 +26,7 @@ _CITATION = """
 """
-class PiQA(HFTask, MultipleChoiceTask):
+class PiQA(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "piqa"
     DATASET_NAME = None
@@ -43,7 +40,15 @@ class PiQA(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return False
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def _process_doc(self, doc):
         out_doc = {
             "goal": doc["goal"],
             "choices": [doc["sol1"], doc["sol2"]],
...
@@ -15,7 +15,6 @@ have been trained on data not specifically collected to succeed on PROST."
 Homepage: https://github.com/nala-cub/prost
 """
 from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
 _CITATION = """
@@ -36,7 +35,7 @@ _CITATION = """
 """
-class PROST(HFTask, MultipleChoiceTask):
+class PROST(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "corypaik/prost"
     DATASET_NAME = None
@@ -50,6 +49,9 @@ class PROST(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
     def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
         assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
         return super().fewshot_context(
@@ -59,7 +61,7 @@ class PROST(HFTask, MultipleChoiceTask):
             description=description
         )
-    def _convert_standard(self, doc):
+    def _process_doc(self, doc):
         out_doc = {
             "query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:",
             "choices": [doc['A'], doc['B'], doc['C'], doc['D']],
...
@@ -16,9 +16,8 @@ and (4) a yes/no/maybe answer which summarizes the conclusion.
 Homepage: https://pubmedqa.github.io/
 """
 import numpy as np
-from .common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean
 _CITATION = """
@@ -32,7 +31,7 @@ _CITATION = """
 """
-class Pubmed_QA(HFTask):
+class Pubmed_QA(Task):
     VERSION = 0
     DATASET_PATH = "pubmed_qa"
     DATASET_NAME = "pqa_labeled"
@@ -49,7 +48,7 @@ class Pubmed_QA(HFTask):
     def test_docs(self):
         if self.has_test_docs():
             # HF is labelled as train but its really just for testing
-            return self.data["train"]
+            return self.dataset["train"]
     def doc_to_text(self, doc):
         ctxs = "\n".join(doc["context"]["contexts"])
...
@@ -13,9 +13,6 @@ and Entrance Exam.
 Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php
 """
-import os
-import xml.etree.ElementTree as ET
-from best_download import download_file
 from lm_eval.base import MultipleChoiceTask
@@ -31,35 +28,8 @@ _CITATION = """
 class QA4MRE(MultipleChoiceTask):
     VERSION = 0
-    YEAR = None
-    def download(self):
-        year = self.YEAR
-        lang = "EN"
-        base_path = (
-            "http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
-            "file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
-        )
-        # TODO: add side tasks?
-        variable_year_path = {
-            2011: '2011/Training_Data/Goldstandard/',
-            2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
-            2013: '2013/Main_Task/Training_Data/Goldstandard/'
-        }
-        sha256sums = {
-            2011 : "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
-            2012 : "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
-            2013 : "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
-        }
-        vpath = variable_year_path[year]
-        url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
-        if not os.path.exists("data/qa4mre"):
-            os.makedirs("data/qa4mre", exist_ok=True)
-        if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}"):
-            download_file(
-                url_path,
-                local_file=f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
-                expected_checksum=sha256sums[year],
-            )
+    DATASET_PATH = "qa4mre"
+    DATASET_NAME = None
     def has_training_docs(self):
         return False
@@ -70,39 +40,31 @@ class QA4MRE(MultipleChoiceTask):
     def has_test_docs(self):
         return True
-    def _convert_standard(self, question):
-        choices = [i.text for i in question.iter('answer')]
+    def test_docs(self):
+        # `qa4mre` only has train data so we use it for the test docs.
+        return map(self._process_doc, self.dataset["train"])
+    def _process_doc(self, doc):
+        choices = doc["answer_options"]["answer_str"]
         out_doc = {
-            "query" : question.find('q_str').text,
-            "choices": choices,
-            "gold" : int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
+            "source": doc["document_str"].strip().replace("\'", "'"),
+            "query": doc["question_str"],
+            "choices": choices,
+            "gold": int(doc["correct_answer_id"]) - 1,
         }
         return out_doc
-    def load_docs(self, textfilename, tfds=False):
-        tree = ET.parse(textfilename)
-        root = tree.getroot()
-        # TODO: context is much larger than the context sometimes
-        # at the moment, it just gets left-truncated by LM automatically, and maybe that's good enough?
-        for reading_test in root.iter('reading-test'):
-            src = reading_test[0].text
-            src = src.strip().replace("\'", "'")
-            for qid, question in enumerate(reading_test.iter('q')):
-                out_doc = self._convert_standard(question)
-                out_doc['source'] = src
-                yield out_doc
-    def test_docs(self):
-        return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
 class QA4MRE_2011(QA4MRE):
-    YEAR = 2011
+    DATASET_NAME = "2011.main.EN"
 class QA4MRE_2012(QA4MRE):
-    YEAR = 2012
+    DATASET_NAME = "2012.main.EN"
 class QA4MRE_2013(QA4MRE):
-    YEAR = 2013
+    DATASET_NAME = "2013.main.EN"
@@ -11,13 +11,10 @@ provide supporting evidence to answers.
 Homepage: https://allenai.org/data/qasper
 """
 from collections import Counter
-from math import exp
-import random
 import re
 import string
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
 from lm_eval.metrics import f1_score, mean
-from .common import HFTask
 _CITATION = """
@@ -104,11 +101,20 @@ def token_f1_score(prediction, ground_truth):
     return f1
-class QASPER(HFTask):
+class QASPER(Task):
     VERSION = 0
     DATASET_PATH = "qasper"
     DATASET_NAME = None
+    def has_training_docs(self):
+        return True
+    def has_validation_docs(self):
+        return True
+    def has_test_docs(self):
+        return False
     def doc_to_text(self, doc):
         return (
             "TITLE: "
@@ -130,14 +136,14 @@ class QASPER(HFTask):
         return " " + answer
     def training_docs(self):
-        for doc in self.data["train"]:
-            yield from self.process_doc(doc)
+        for doc in self.dataset["train"]:
+            yield from self._process_doc(doc)
     def validation_docs(self):
-        for doc in self.data["train"]:
-            yield from self.process_doc(doc)
+        for doc in self.dataset["validation"]:
+            yield from self._process_doc(doc)
-    def process_doc(self, doc):
+    def _process_doc(self, doc):
         """Given a `doc`, flatten it out so that each JSON blob
         contains exactly one question and one answer. Logic taken from
         the reference implementation available at
...
@@ -10,10 +10,9 @@ a teacher who answers the questions by providing short excerpts (spans) from the
 Homepage: https://quac.ai/
 """
-import json
-import os
+import inspect
+import lm_eval.datasets.quac.quac
 from lm_eval.base import Task
-from ..utils import sh
 _CITATION = """
@@ -28,18 +27,8 @@ _CITATION = """
 class QuAC(Task):
     VERSION = 0
-    def __init__(self):
-        super().__init__()
-    def download(self):
-        if not os.path.exists('data/quac'):
-            # TODO: convert to use best_download
-            sh("""
-                mkdir -p data/quac
-                wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -O data/quac/train_v0.2.json
-                wget https://s3.amazonaws.com/my89public/quac/val_v0.2.json -O data/quac/val_v0.2.json
-                """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac)
+    DATASET_NAME = None
     def has_training_docs(self):
         return True
@@ -51,28 +40,20 @@ class QuAC(Task):
         return False
     def training_docs(self):
-        myjson = json.load(open('data/quac/train_v0.2.json'))['data']
-        return self.load_doc(myjson)
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
     def validation_docs(self):
-        myjson = json.load(open('data/quac/val_v0.2.json'))['data']
-        return self.load_doc(myjson)
+        return map(self._process_doc, self.dataset["validation"])
     def test_docs(self):
         raise NotImplementedError("QuAC has no test docs.")
-    def load_doc(self, myjson):
-        docs = []
-        for item in myjson:
-            title = item['title'] + ' - ' + item['section_title']
-            paragraph = item['paragraphs'][0]['context'].replace("CANNOTANSWER", "")
-            qas = item['paragraphs'][0]['qas']
-            qa_pairs = [(qa['question'], qa['answers'][0]['text']) for qa in qas]
-            for (question, answer) in qa_pairs:
-                doc = { 'title': title, 'paragraph': paragraph, 'question': question, 'answer': answer }
-                docs.append(doc)
-        return docs
+    def _process_doc(self, doc):
+        doc["title"] = doc['title'] + ' - ' + doc['section_title']
+        return doc
     def doc_to_text(self, doc):
         return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '
@@ -88,7 +69,7 @@ class QuAC(Task):
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # TODO: implement evaluation.
        raise NotImplementedError('Evaluation not implemented')
...
@@ -12,9 +12,8 @@ Homepage: https://www.cs.cmu.edu/~glai1/data/race/
 import collections
 import datasets
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean
 _CITATION = """
@@ -35,16 +34,14 @@ class each:
         return list(map(self.f, other))
-class RACE(HFTask):
-    VERSION = 0
+class RACE(Task):
+    VERSION = 1
     DATASET_PATH = "race"
     DATASET_NAME = "high"
     cache = {}
     letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
-    assert datasets.__version__ == "1.15.1", "RACE requires datasets==1.15.1!"
     def has_training_docs(self):
         return True
...
@@ -7,7 +7,8 @@ multiple-choice analogy questions; 5 choices per question.
 Homepage: https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art)
 """
-import os
+import inspect
+import lm_eval.datasets.sat_analogies.sat_analogies
 from lm_eval.base import MultipleChoiceTask
@@ -25,20 +26,18 @@ _CITATION = """
 """
 class SATAnalogies(MultipleChoiceTask):
     VERSION = 0
-    NEEDS_MANUAL_DL = True
-    def __init__(self):
-        super().__init__()
-    def download(self):
-        # We should be using a checksum here.
-        # The canonical sha256 hash is below:
-        # 9dece377d8d57253ef8c78370ff15de0bb1d9e90a82c815a67ba1e621e921bfc
-        if not os.path.exists('data/sat/SAT-package-V3.txt'):
-            raise NotImplementedError('SAT Analogies dataset is not provided. Follow instructions on https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art) to locate.')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.sat_analogies.sat_analogies)
+    DATASET_NAME = None
+    def __init__(self, data_dir: str):
+        """
+        SAT Analog Questions is not publicly available. You must request the data
+        by emailing Peter Turney and then download it to a local directory path
+        which should be passed into the `data_dir` arg.
+        """
+        super().__init__(data_dir=data_dir)
     def has_training_docs(self):
         return False
@@ -51,38 +50,20 @@ class SATAnalogies(MultipleChoiceTask):
     def training_docs(self):
         return []
-    def test_docs(self):
-        return []
     def validation_docs(self):
-        data = []
-        with open("data/sat/SAT-package-V3.txt", "r") as f:
-            record = []
-            for line in f:
-                line = line.strip()
-                if len(line) == 0 and record:
-                    data.append(record)
-                    record = []
-                elif len(line) > 0 and line[0] == '#':
-                    continue
-                else:
-                    record.append(line)
-            data.append(record)
-        for record in data:
-            source = record[-8]
-            query = record[-7]
-            choices = record[-6:-1]
-            answer_key = record[-1]
-            doc = {
-                'source': source,
-                'query': query.split(' ')[:2],
-                'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in choices],
-                'gold': ['a','b','c','d','e'].index(answer_key.strip()),
-            }
-            yield doc
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return []
+    def _process_doc(self, doc):
+        return {
+            'source': doc['source'],
+            'query': doc['stem'].split(' ')[:2],
+            'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]],
+            'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()),
+        }
     def doc_to_text(self, doc):
         return "{} is to {} as".format(*doc['query'])
@@ -9,11 +9,7 @@ with supporting evidence for the correct answer is provided.
 Homepage: https://allenai.org/data/sciq
 """
-import os
-import json
-import zipfile
 from lm_eval.base import MultipleChoiceTask
-from best_download import download_file
 _CITATION = """
@@ -28,17 +24,8 @@ _CITATION = """
 class SciQ(MultipleChoiceTask):
     VERSION = 0
-    # Multiple languages and multiple years
-    def download(self):
-        if not os.path.exists('data/sciq'):
-            os.makedirs('data/sciq', exist_ok=True)
-            download_file(
-                'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
-                local_file='data/sciq/SciQ.zip',
-                expected_checksum='7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
-            )
-            with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
-                zf.extractall("data/sciq/")
+    DATASET_PATH = "sciq"
+    DATASET_NAME = None
     def has_training_docs(self):
         return True
@@ -49,36 +36,32 @@ class SciQ(MultipleChoiceTask):
     def has_test_docs(self):
         return True
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+    def _process_doc(self, doc):
         choices = [
             doc["distractor1"],
             doc["distractor2"],
             doc["distractor3"],
             doc["correct_answer"],
         ]
         src = doc['support']
         out_doc = {
-            "source" : src,
-            "query" : doc['question'],
-            "choices" : choices,
-            "gold" : 3,
+            "source": src,
+            "query": doc['question'],
+            "choices": choices,
+            "gold": 3,
         }
         return out_doc
-    def load_docs(self, textfilename):
-        with open(textfilename, 'r') as j:
-            docs = json.loads(j.read())
-        for record in docs:
-            yield self._convert_standard(record)
-    def training_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
-    def validation_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
-    def test_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
@@ -15,9 +15,7 @@ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
 """
 import datasets
 from math import exp
-from lm_eval.base import rf
-from lm_eval.metrics import f1_score, mean
-from . common import HFTask
+from lm_eval.base import rf, Task
 from functools import partial
 from packaging import version
@@ -45,7 +43,7 @@ def _squad_agg(key, items):
     return _squad_metric(predictions=predictions, references=references)[key]
-class SQuAD2(HFTask):
+class SQuAD2(Task):
     VERSION = 1
     DATASET_PATH = "squad_v2"
     DATASET_NAME = None
@@ -63,10 +61,10 @@ class SQuAD2(HFTask):
         return False
     def training_docs(self):
-        return self.data["train"]
+        return self.dataset["train"]
     def validation_docs(self):
-        return self.data["validation"]
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
...
@@ -8,8 +8,9 @@ to choose the correct ending to a four-sentence story.
 Homepage: https://cs.rochester.edu/nlp/rocstories/
 """
-import csv
-from lm_eval.base import Task
+import numpy as np
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean
 _CITATION = """
@@ -34,11 +35,16 @@ _CITATION = """
 class StoryCloze(Task):
     VERSION = 0
-    NEEDS_MANUAL_DL = True
-    def download(self):
-        #TODO: replace with Eye link
-        pass
+    DATASET_PATH = "story_cloze"
+    DATASET_NAME = None
+    def __init__(self, data_dir: str):
+        """
+        StoryCloze is not publicly available. You must download the data by
+        following https://cs.rochester.edu/nlp/rocstories/ and pass the folder
+        path into the `data_dir` arg.
+        """
+        super().__init__(data_dir=data_dir)
     def has_training_docs(self):
         return False
@@ -52,40 +58,46 @@ class StoryCloze(Task):
     def training_docs(self):
         pass
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return list(filereader)
     def validation_docs(self):
-        return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
+        return self.dataset["validation"]
     def test_docs(self):
-        return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
+        return self.dataset["test"]
     def doc_to_text(self, doc):
-        return ' '.join([*doc[1:5]])
+        return ' '.join([
+            doc["input_sentence_1"],
+            doc["input_sentence_2"],
+            doc["input_sentence_3"],
+            doc["input_sentence_4"],
+        ])
     def doc_to_target(self, doc):
-        return " " + doc[int(doc[-1]) - 4]
+        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
+        # `- 1` because the `answer_right_ending` index is 1-based.
+        return " " + clozes[doc["answer_right_ending"] - 1]
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
+        lls = [
+            rf.loglikelihood(ctx, " {}".format(choice))[0]
+            for choice in clozes
+        ]
+        return lls
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
         :param doc:
@@ -93,23 +105,36 @@ class StoryCloze(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = doc["answer_right_ending"] - 1
+        acc = 1. if np.argmax(results) == gold else 0.
+        return {
+            "acc": acc
+        }
     def aggregation(self):
         """
         :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }
     def higher_is_better(self):
         """
         :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }
+class StoryCloze2016(StoryCloze):
+    DATASET_NAME = "2016"
+class StoryCloze2018(StoryCloze):
+    DATASET_NAME = "2018"
@@ -12,10 +12,9 @@ TODO: WSC requires free-form generation.
 import numpy as np
 import sklearn
 import transformers.data.metrics.squad_metrics as squad_metrics
-from . common import HFTask, yesno
-from lm_eval.base import rf
-from ..metrics import mean, acc_all, metric_max_over_ground_truths
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
+from lm_eval.utils import general_detokenize
 _CITATION = """
@@ -33,7 +32,7 @@ _CITATION = """
 """
-class BoolQ(HFTask):
+class BoolQ(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "boolq"
@@ -47,6 +46,14 @@ class BoolQ(HFTask):
     def has_test_docs(self):
         return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
@@ -81,7 +88,7 @@ class BoolQ(HFTask):
         }
-class CommitmentBank(HFTask):
+class CommitmentBank(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "cb"
@@ -95,6 +102,14 @@ class CommitmentBank(HFTask):
     def has_test_docs(self):
         return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
             doc["premise"],
@@ -148,7 +163,7 @@ class CommitmentBank(HFTask):
         }
-class Copa(HFTask):
+class Copa(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "copa"
@@ -162,6 +177,14 @@ class Copa(HFTask):
     def has_test_docs(self):
         return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         # Drop the period
         connector = {
@@ -208,7 +231,7 @@ class Copa(HFTask):
         return choice[0].lower() + choice[1:]
-class MultiRC(HFTask):
+class MultiRC(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "multirc"
@@ -222,6 +245,14 @@ class MultiRC(HFTask):
     def has_test_docs(self):
         return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
@@ -260,7 +291,7 @@ class MultiRC(HFTask):
         }
-class ReCoRD(HFTask):
+class ReCoRD(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "record"
@@ -279,13 +310,13 @@ class ReCoRD(HFTask):
         # Each doc consists of multiple answer candidates, each of which is scored yes/no.
         if self._training_docs is None:
             self._training_docs = []
-            for doc in self.data["train"]:
+            for doc in self.dataset["train"]:
                 self._training_docs.append(self._process_doc(doc))
         return self._training_docs
     def validation_docs(self):
         # See: training_docs
-        for doc in self.data["validation"]:
+        for doc in self.dataset["validation"]:
             yield self._process_doc(doc)
     @classmethod
@@ -349,7 +380,7 @@ class ReCoRD(HFTask):
         }
-class WordsInContext(HFTask):
+class WordsInContext(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "wic"
@@ -363,6 +394,14 @@ class WordsInContext(HFTask):
     def has_test_docs(self):
         return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
             " two sentences above?\nAnswer:".format(
@@ -401,7 +440,7 @@ class WordsInContext(HFTask):
         }
-class SGWinogradSchemaChallenge(HFTask):
+class SGWinogradSchemaChallenge(Task):
     VERSION = 0
     # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
     # binary version of the task.
@@ -423,11 +462,14 @@ class SGWinogradSchemaChallenge(HFTask):
             # GPT-3 Paper's format only uses positive examples for fewshot "training"
             self._training_docs = [
                 doc for doc in
-                self.data["train"]
+                self.dataset["train"]
                 if doc["label"]
             ]
         return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
     def doc_to_text(self, doc):
         raw_passage = doc["text"]
         # NOTE: HuggingFace span indices are word-based not character-based.
...
@@ -90,7 +90,7 @@ class GeneralTranslationTask(Task):
         super().__init__()
-    def download(self):
+    def download(self, data_dir=None, cache_dir=None, download_mode=None):
         # This caches in the users home dir automatically
         self.src_file, self.ref_file = \
             sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
...
@@ -9,13 +9,10 @@ high quality distant supervision for answering the questions.
 Homepage: https://nlp.cs.washington.edu/triviaqa/
 """
-import os
-import json
-import jsonlines
+import inspect
+import lm_eval.datasets.triviaqa.triviaqa
 from lm_eval.base import Task, rf
-from ..metrics import mean
-from ..utils import sh
-from best_download import download_file
+from lm_eval.metrics import mean
 _CITATION = """
@@ -33,14 +30,8 @@ _CITATION = """
 class TriviaQA(Task):
     VERSION = 0
-    def download(self):
-        if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
-            os.makedirs("data/triviaqa/", exist_ok=True)
-            download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", local_file="data/triviaqa/triviaqa-unfiltered.tar.gz", expected_checksum="adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
-            sh("""
-                cd data/triviaqa/
-                tar -xf triviaqa-unfiltered.tar.gz
-                """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
+    DATASET_NAME = None
     def has_training_docs(self):
         return True
@@ -52,19 +43,19 @@ class TriviaQA(Task):
         return False
     def training_docs(self):
-        return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
+        return self.dataset['train']
     def validation_docs(self):
-        return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
+        return self.dataset['validation']
     def test_docs(self):
         raise NotImplementedError()
     def doc_to_text(self, doc):
-        return f"Question: {doc['Question']}\nAnswer:"
+        return f"Question: {doc['question']}\nAnswer:"
     def doc_to_target(self, doc):
-        return " " + doc['Answer']['Value']
+        return " " + doc['answer']['value']
     def _remove_prefixes(self, aliases):
         # Optimization: Remove any alias that has a strict prefix elsewhere in the list
@@ -74,12 +65,11 @@ class TriviaQA(Task):
         for alias in aliases[1:]:
             if not alias.startswith(ret[-1]):
                 ret.append(alias)
         return ret
     def construct_requests(self, doc, ctx):
         ret = []
-        for alias in self._remove_prefixes(doc['Answer']['Aliases']):
+        for alias in self._remove_prefixes(doc['answer']['aliases']):
             _, is_prediction = rf.loglikelihood(ctx, " " + alias)
             ret.append(is_prediction)
         return ret
...
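Editorial note on the TriviaQA hunk above: besides the download refactor, the code tracks the HF dataset schema, so the old capitalized keys (doc['Question'], doc['Answer']['Value'], doc['Answer']['Aliases']) become lower-case question / answer.value / answer.aliases. A small sketch of the document shape the new code expects; only the key layout comes from the diff, the values are invented for illustration:

# Invented example document; only the key layout is taken from the diff above.
doc = {
    "question": "Which author wrote 'The Hobbit'?",
    "answer": {
        "value": "J. R. R. Tolkien",
        "aliases": ["Tolkien", "J.R.R. Tolkien", "John Ronald Reuel Tolkien"],
    },
}

prompt = f"Question: {doc['question']}\nAnswer:"   # mirrors doc_to_text
target = " " + doc["answer"]["value"]              # mirrors doc_to_target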