Unverified commit 6caa0afd, authored by Leo Gao and committed by GitHub

Merge pull request #300 from jon-tow/hf-dataset-refactor

Refactor `Task` downloading to use `HuggingFace.datasets`
parents 7064d6b9 9434722c
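
In short, the refactor replaces per-task `download()` implementations (best_download calls, shell commands, local `data/` paths) with declarative `DATASET_PATH`/`DATASET_NAME` attributes backed by HuggingFace `datasets`. A minimal sketch of the resulting pattern, assuming the new base `Task` populates `self.dataset` the same way as the `HFTask.download` helper removed later in this diff (the class below is illustrative, not code from the PR):

import datasets


class ExampleHFBackedTask:               # illustrative stand-in for an lm_eval.base.Task subclass
    DATASET_PATH = "anli"                # HF Hub dataset name, or a path to a local dataset script
    DATASET_NAME = None                  # optional sub-config / subset name

    def download(self):
        # Assumed to mirror the removed HFTask.download shown further down in this diff:
        self.dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)

    def validation_docs(self):
        # Splits are then plain `datasets.Dataset` objects indexed by split name.
        return self.dataset["dev_r1"]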
@@ -10,9 +10,8 @@ provided explanations.
Homepage: "https://github.com/facebookresearch/anli"
"""
import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

_CITATION = """
@@ -31,7 +30,7 @@ _CITATION = """
"""

-class ANLIBase(HFTask):
+class ANLIBase(Task):
    VERSION = 0
    DATASET_PATH = "anli"
    DATASET_NAME = None
@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
-                self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
+                self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["dev_r" + str(self.SPLIT)]
+            return self.dataset["dev_r" + str(self.SPLIT)]

    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_r" + str(self.SPLIT)]
+            return self.dataset["test_r" + str(self.SPLIT)]

    def doc_to_text(self, doc):
        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
@@ -125,11 +124,14 @@ class ANLIBase(HFTask):
            "acc": True
        }

class ANLIRound1(ANLIBase):
    SPLIT = 1

class ANLIRound2(ANLIBase):
    SPLIT = 2

class ANLIRound3(ANLIBase):
    SPLIT = 3
@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
-from . common import HFTask

_CITATION = """
@@ -27,7 +26,7 @@ _CITATION = """
"""

-class ARCEasy(HFTask, MultipleChoiceTask):
+class ARCEasy(MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "ai2_arc"
    DATASET_NAME = "ARC-Easy"
@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
        # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
        # of {'1', '2', '3', '4', '5'}. We map them back to letters.
        num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
...
@@ -7,13 +7,10 @@ problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
-import abc
-import json
-import os
-from collections import namedtuple
+import inspect
+import lm_eval.datasets.arithmetic.arithmetic
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
-from best_download import download_file

_CITATION = """
@@ -31,33 +28,9 @@ _CITATION = """
"""

-ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
-
class Arithmetic(Task):
    VERSION = 0
-    directory = 'data/arithmetic/'
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)

-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        file_name, checksum = self.get_file_download_info()
-        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
-        if not os.path.exists(self.directory):
-            os.makedirs(self.directory)
-        download_file(url, local_file=self.directory+file_name, expected_checksum=checksum)
-        self.set_docs()
-
-    @abc.abstractmethod
-    def get_file_download_info(self):
-        """returns a tuple of (file_name, checksum)"""
-        pass
-
-    def set_docs(self):
-        file_name, _ = self.get_file_download_info()
-        jsons = open(self.directory+file_name, 'r')
-        self._docs = [self.load_doc(json.loads(line)) for line in jsons]
-
    def has_training_docs(self):
        return False
@@ -72,25 +45,19 @@ class Arithmetic(Task):
        return NotImplemented

    def validation_docs(self):
-        return self._docs
+        return self.dataset["validation"]

    def test_docs(self):
        return NotImplemented

    def doc_to_text(self, doc):
-        return doc.context
+        return doc["context"]

    def doc_to_target(self, doc):
-        return doc.completion
+        return doc["completion"]

-    def load_doc(self, doc_json):
-        return ArithmeticDoc(context=doc_json['context'].strip()
-            .replace('\n\n', '\n')
-            .replace('Q:', 'Question:')
-            .replace('A:', 'Answer:'), completion=doc_json['completion'])
-
    def construct_requests(self, doc, ctx):
-        ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
+        ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
        return is_prediction

    def process_results(self, doc, results):
@@ -111,41 +78,40 @@ class Arithmetic(Task):
class Arithmetic2DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
+    DATASET_NAME = "arithmetic_2da"

class Arithmetic2DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
+    DATASET_NAME = "arithmetic_2ds"

class Arithmetic3DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
+    DATASET_NAME = "arithmetic_3da"

class Arithmetic3DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
+    DATASET_NAME = "arithmetic_3ds"

class Arithmetic4DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
+    DATASET_NAME = "arithmetic_4da"

class Arithmetic4DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
+    DATASET_NAME = "arithmetic_4ds"

class Arithmetic5DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
+    DATASET_NAME = "arithmetic_5da"

class Arithmetic5DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
+    DATASET_NAME = "arithmetic_5ds"

class Arithmetic2DMultiplication(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
+    DATASET_NAME = "arithmetic_2dm"

class Arithmetic1DComposite(Arithmetic):
-    def get_file_download_info(self):
-        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
+    DATASET_NAME = "arithmetic_1dc"
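
Several of the refactored tasks (Arithmetic, ASDiv, CoQA, DROP, GSM8K, HeadQA, the ethics and MATH suites) set `DATASET_PATH = inspect.getfile(...)` on a module under `lm_eval.datasets`. `inspect.getfile` returns the filesystem path of that module's `.py` file, so `DATASET_PATH` names a dataset script bundled with the harness rather than a Hub dataset; `datasets.load_dataset` accepts either form. A small sketch of what this resolves to (the printed path and the commented call are illustrative assumptions):

import inspect
import lm_eval.datasets.arithmetic.arithmetic  # dataset script packaged with the harness

script_path = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)
print(script_path)  # e.g. ".../lm_eval/datasets/arithmetic/arithmetic.py" (illustrative)

# Assumed downstream use, matching the HFTask.download call removed by this PR:
# import datasets
# data = datasets.load_dataset(path=script_path, name="arithmetic_2da")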
@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
"""
-from lm_eval.base import Task
-from pathlib import Path
-from best_download import download_file
-import xml.etree.ElementTree as ET
-from lm_eval.base import rf
-from lm_eval.metrics import mean,perplexity
-import numpy as np
-from zipfile import ZipFile
-import os
+import inspect
+import lm_eval.datasets.asdiv.asdiv
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

_CITATION = """
@@ -39,39 +34,11 @@ _CITATION = """
class Asdiv(Task):
    VERSION = 0
-    DATASET_PATH = Path("data/asdiv")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)

-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
-        checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
-        zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
-        os.remove(zip_path)
-
-    def _convert_standard(self, problem):
-        #TODO: include solution-type and formula
-        out_doc = {
-            "question" : problem.find('Question').text,
-            "body" : problem.find('Body').text,
-            "answer": problem.find('Answer').text
-        }
-        return out_doc
-
-    def load_docs(self, textfilename, tfds=False):
-        tree = ET.parse(textfilename)
-        root = tree.getroot()
-        for pid, problem in enumerate(root.iter('Problem')):
-            out_doc = self._convert_standard(problem)
-            yield out_doc
-
    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True
@@ -81,13 +48,12 @@ class Asdiv(Task):
    def training_docs(self):
        raise NotImplementedError("This dataset has no training docs")

+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def test_docs(self):
        raise NotImplementedError("This dataset has no test docs")

-    def validation_docs(self):
-        data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
-        return self.load_docs(data_xml_path)
-
    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
        return super().fewshot_context(
...
@@ -10,9 +10,8 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask

_CITATION = """
@@ -32,19 +31,24 @@ _CITATION = """
"""

-class BlimpTask(HFTask):
+class BlimpTask(Task):
    VERSION = 0
    DATASET_PATH = "blimp"

-    def download(self):
-        super().download()
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return False
+
+    def validation_docs(self):
        # The HF dataset only contains a "train" dataset, but the harness expects a "validation"
        # dataset. Let's use the training dataset, on the assumption that the model wasn't actually
        # trained on this data.
-        self.data["validation"] = self.data["train"]
-        del self.data["train"]
+        return self.dataset["train"]

    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0
...
@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask

_CITATION = """
@@ -30,11 +29,30 @@ _CITATION = """
"""

-class CBTBase(HFTask):
+class CBTBase(Task):
    VERSION = 0
    DATASET_PATH = "cbt"
    DATASET_NAME = None

+    def has_training_docs(self):
+        return True
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
    def detokenize(self, text):
        text = text.replace(" '", "'")
...
-import datasets
-
-from ..base import Task
-
-
-class HFTask(Task):
-    DATASET_PATH = None
-    DATASET_NAME = None
-
-    def __init__(self):
-        self.data = None
-        super().__init__()
-
-    def download(self):
-        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
-
-    def has_training_docs(self):
-        """Whether the task has a training set"""
-        return True if "train" in self.data.keys() else False
-
-    def has_validation_docs(self):
-        """Whether the task has a validation set"""
-        return True if "validation" in self.data.keys() else False
-
-    def has_test_docs(self):
-        """Whether the task has a test set"""
-        return True if "test" in self.data.keys() else False
-
-    def _convert_standard(self, doc):
-        return doc
-
-    def training_docs(self):
-        # Cache training for faster few-shot.
-        # If data is too large to fit in memory, override this method.
-        if self.has_training_docs():
-            if self._training_docs is None:
-                self._training_docs = list(map(self._convert_standard, self.data["train"]))
-            return self._training_docs
-
-    def validation_docs(self):
-        if self.has_validation_docs():
-            return map(self._convert_standard, self.data["validation"])
-
-    def test_docs(self):
-        if self.has_test_docs():
-            return map(self._convert_standard, self.data["test"])
-
-
-def yesno(x):
-    if x:
-        return 'yes'
-    else:
-        return 'no'
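
The block above is the removed `HFTask` helper, whose split auto-detection (`"train" in self.data`) and `_convert_standard` hook the tasks no longer inherit. Each task in this diff instead declares its `has_*_docs` flags and `*_docs` accessors directly, usually caching the training split; a condensed sketch of that repeated boilerplate, assuming the base class fills in `self.dataset` (the class name is illustrative):

class ExampleTask:                      # illustrative; real tasks subclass lm_eval.base.Task
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"

    def __init__(self):
        self.dataset = None             # assumed to be set by the base class's download()
        self._training_docs = None

    def has_training_docs(self):
        return True                     # declared explicitly instead of probing self.data

    def training_docs(self):
        # Cache the training split so few-shot sampling does not re-iterate the dataset.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]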
@@ -9,13 +9,11 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
-import os
-import json
+import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
+import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
-from ..utils import sh
from itertools import zip_longest
-from best_download import download_file

_CITATION = """
@@ -32,15 +30,8 @@ _CITATION = """
class CoQA(Task):
    VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
+    DATASET_NAME = None

-    def download(self):
-        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
-        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
-        sh ("""mkdir -p data/coqa""")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
-
    def has_training_docs(self):
        return True
@@ -52,10 +43,10 @@ class CoQA(Task):
        return False

    def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        return self.dataset["train"]

    def validation_docs(self):
-        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
+        return self.dataset["validation"]

    def test_docs(self):
        pass
@@ -64,9 +55,9 @@ class CoQA(Task):
        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
        # and a question qi, the task is to predict the answer ai
        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer ai
-            question = f"Q: {q['input_text']}" + '\n\n'
-            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
+            question = f"Q: {q}\n\n"
+            answer = f"A: {a}\n\n" if a is not None else "A:"
            doc_text += question + answer
        return doc_text
@@ -74,13 +65,13 @@ class CoQA(Task):
    def get_answers(cls, doc, turn_id):
        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
        answers = []
-        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
        answers.append(answer_forturn)

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers
@@ -120,8 +111,8 @@ class CoQA(Task):
    def doc_to_target(self, doc, turnid=None):
        # Default to prediction of last turn.
        if turnid is None:
-            turnid = len(doc["questions"])
-        raw_text = doc['answers'][turnid - 1]["input_text"]
+            turnid = len(doc["questions"]["input_text"])
+        raw_text = doc['answers']["input_text"][turnid - 1]
        return " " + raw_text

    def construct_requests(self, doc, ctx):
@@ -148,7 +139,7 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
-        turn_id = len(doc["questions"])
+        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        pred = results[0].strip().split('\n')[0]
...
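
The CoQA indexing changes above follow from the column-oriented layout of the HF-datasets version: each document stores the turns' fields as parallel lists under `input_text`, rather than a list of per-turn dicts. Roughly, and inferred from the diff rather than checked against the dataset itself:

# Inferred shape of one CoQA doc after the refactor (values illustrative, keys
# limited to those the diff actually touches):
doc = {
    "story": "Once upon a time ...",
    "questions": {"input_text": ["Who is the story about?", "Where does she live?"]},
    "answers": {"input_text": ["Ann", "Rome"]},
}

turn_id = len(doc["questions"]["input_text"])        # 2, as in process_results above
answer = doc["answers"]["input_text"][turn_id - 1]   # "Rome"
# versus the old row-oriented access: doc["answers"][turn_id - 1]["input_text"]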
@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
-import json
+import inspect
import numpy as np
import re
import string
-from best_download import download_file
+import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
-from pathlib import Path
-from zipfile import ZipFile

_CITATION = """
@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task):
    VERSION = 1
-    DATASET_PATH = Path("data/drop")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_NAME = None

-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
-        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
-        zip_path = self.DATASET_PATH / "drop_dataset.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
-
    def has_training_docs(self):
        return True
@@ -63,29 +51,46 @@ class DROP(Task):
    def has_test_docs(self):
        return False

-    def _load_docs(self, docs):
-        for doc in docs:
-            for qa in doc["qa_pairs"]:
-                yield {
-                    "id": qa["query_id"],
-                    "passage": doc["passage"],
-                    "question": qa["question"],
-                    "answers": self.get_answers(qa),
-                }
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
+        return {
+            "id": doc["query_id"],
+            "passage": doc["passage"],
+            "question": doc["question"],
+            "answers": self.get_answers(doc),
+        }

    @classmethod
    def get_answers(cls, qa):
+        def _flatten_validated_answers(validated_answers):
+            """ Flattens a dict of lists of validated answers.
+            {"number": ['1', '8'], ...}
+            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
+            """
+            vas = []
+            for i in range(len(validated_answers["number"])):
+                vas.append({
+                    "number": validated_answers["number"][i],
+                    "date": validated_answers["date"][i],
+                    "spans": validated_answers["spans"][i],
+                })
+            return vas
        answers = []
        answers_set = set()
-        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
                continue
            answers_set.add(answer)
            answers.append(answer)
        return answers

    @classmethod
@@ -99,14 +104,6 @@ class DROP(Task):
                answer["date"]["month"],
                answer["date"]["year"]]).strip(),)

-    def training_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
-    def validation_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
    def doc_to_text(self, doc):
        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
...
@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
Homepage: https://gluebenchmark.com/
"""
import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean, matthews_corrcoef, f1_score
-from . common import HFTask, yesno
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
+from lm_eval.utils import general_detokenize

# TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.
@@ -46,7 +45,7 @@ _CITATION = """
# Single-Sentence Tasks

-class CoLA(HFTask):
+class CoLA(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"
@@ -60,6 +59,14 @@ class CoLA(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
@@ -90,7 +97,7 @@ class CoLA(HFTask):
        }

-class SST(HFTask):
+class SST(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "sst2"
@@ -104,6 +111,14 @@ class SST(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
            general_detokenize(doc["sentence"]),
@@ -139,7 +154,7 @@ class SST(HFTask):
# Inference Tasks

-class MNLI(HFTask):
+class MNLI(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "mnli"
@@ -153,13 +168,18 @@ class MNLI(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["validation_matched"]
+            return self.dataset["validation_matched"]

    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_matched"]
+            return self.dataset["test_matched"]

    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
@@ -202,14 +222,14 @@ class MNLIMismatched(MNLI):
    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["validation_mismatched"]
+            return self.dataset["validation_mismatched"]

    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_mismatched"]
+            return self.dataset["test_mismatched"]

-class QNLI(HFTask):
+class QNLI(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "qnli"
@@ -223,6 +243,14 @@ class QNLI(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
            doc["question"],
@@ -258,7 +286,7 @@ class QNLI(HFTask):
        }

-class WNLI(HFTask):
+class WNLI(Task):
    VERSION = 1
    DATASET_PATH = "glue"
    DATASET_NAME = "wnli"
@@ -272,6 +300,14 @@ class WNLI(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True or False?\nAnswer:".format(
            doc["sentence1"],
@@ -307,7 +343,7 @@ class WNLI(HFTask):
        }

-class RTE(HFTask):
+class RTE(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "rte"
@@ -321,6 +357,14 @@ class RTE(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True or False?\nAnswer:".format(
            doc["sentence1"],
@@ -359,7 +403,7 @@ class RTE(HFTask):
# Similarity and Paraphrase Tasks

-class MRPC(HFTask):
+class MRPC(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "mrpc"
@@ -373,6 +417,14 @@ class MRPC(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
            general_detokenize(doc["sentence1"]),
@@ -409,7 +461,7 @@ class MRPC(HFTask):
        }

-class QQP(HFTask):
+class QQP(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "qqp"
@@ -423,6 +475,14 @@ class QQP(HFTask):
    def has_test_docs(self):
        return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
    def doc_to_text(self, doc):
        return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
            doc["question1"],
@@ -459,7 +519,7 @@ class QQP(HFTask):
        }

-class STSB(HFTask):
+class STSB(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "stsb"
@@ -473,6 +533,17 @@ class STSB(HFTask):
    def has_test_docs(self):
        return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
    def doc_to_text(self, doc):
        return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
            doc["sentence1"],
...
@@ -16,10 +16,9 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
+import inspect
-import json
import re
-from best_download import download_file
+import lm_eval.datasets.gsm8k.gsm8k
from pathlib import Path
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"
class GradeSchoolMath8K(Task):
    VERSION = 0
-    DATASET_PATH = Path('data/gsm8k')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
+    DATASET_NAME = None

-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
-        splits = [
-            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
-            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.jsonl"
-            url = f"{base_url}/{split['name']}.jsonl"
-            download_file(url, local_file=str(file), expected_checksum=split["checksum"])
-
    def has_training_docs(self):
        return True
@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
    def has_test_docs(self):
        return True

-    def _load_docs(self, file):
-        return (json.loads(line) for line in open(file).read().splitlines())
-
    def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train.jsonl")
+        return self.dataset["train"]

    def validation_docs(self):
        raise NotImplementedError

    def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test.jsonl")
+        return self.dataset["test"]

    def doc_to_text(self, doc):
        return "Question: " + doc['question'] + '\nAnswer:'
...
@@ -8,7 +8,8 @@ even for highly specialized humans.
Homepage: https://aghie.github.io/head-qa/
"""
-from . common import HFTask
+import inspect
+import lm_eval.datasets.headqa.headqa
from lm_eval.base import MultipleChoiceTask

@@ -24,9 +25,9 @@ _CITATION = """
"""

-class HeadQABase(HFTask, MultipleChoiceTask):
+class HeadQABase(MultipleChoiceTask):
    VERSION = 0
-    DATASET_PATH = "head_qa"
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)

    def has_training_docs(self):
        return True
@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
        out_doc = {
            "id": doc["qid"],
            "query": "Question: " + doc["qtext"] + "\nAnswer:",
@@ -49,12 +61,15 @@ class HeadQABase(HFTask, MultipleChoiceTask):
    def doc_to_text(self, doc):
        return doc["query"]

class HeadQAEn(HeadQABase):
    DATASET_NAME = "en"

class HeadQAEs(HeadQABase):
    DATASET_NAME = "es"

# for backwards compatibility
class HeadQAEsDeprecated(HeadQABase):
    DATASET_NAME = "es"
...
@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
-from . common import HFTask

_CITATION = """
@@ -28,7 +27,7 @@ _CITATION = """
"""

-class HellaSwag(HFTask, MultipleChoiceTask):
+class HellaSwag(MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None
@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return False

-    @classmethod
-    def preprocess(cls, text):
-        text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub('\\[.*?\\]', '', text)
-        text = text.replace("  ", " ")
-        return text
-
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
@@ -60,5 +58,14 @@ class HellaSwag(HFTask, MultipleChoiceTask):
        }
        return out_doc

+    @classmethod
+    def preprocess(cls, text):
+        text = text.strip()
+        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+        text = text.replace(" [title]", ". ")
+        text = re.sub('\\[.*?\\]', '', text)
+        text = text.replace("  ", " ")
+        return text
+
    def doc_to_text(self, doc):
        return doc["query"]
@@ -14,17 +14,14 @@ tasks are refered to in this work as the `em` sub-metric. See Section 3. Metrics
of the paper.
Homepage: https://github.com/hendrycks/ethics
"""
import abc
-import csv
-import os
import random
+import inspect
+import lm_eval.datasets.hendrycks_ethics.hendrycks_ethics
import numpy as np
from lm_eval.base import Task, rf
-from lm_eval.metrics import mean
-from lm_eval.utils import sh
-from .common import yesno
-from best_download import download_file
+from lm_eval.metrics import mean, yesno

_CITATION = """
@@ -38,15 +35,8 @@ _CITATION = """
class Ethics(Task):
-    def download(self):
-        if not os.path.exists('data/ethics/done'):
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/ethics.tar", local_file="data/ethics.tar", expected_checksum="40acbf1ac0da79a2aabef394d58889136b8d38b05be09482006de2453fb06333")
-            sh("""
-                tar -xf data/ethics.tar -C data/
-                rm data/ethics.tar
-                touch data/ethics/done
-            """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_ethics.hendrycks_ethics)
+    DATASET_NAME = None

    def has_training_docs(self):
        return True
@@ -57,30 +47,16 @@ class Ethics(Task):
    def has_test_docs(self):
        return True

-    @abc.abstractmethod
-    def process_doc(self, doc):
-        pass
-
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
-
-    @abc.abstractmethod
-    def get_prefix(self):
-        """returns string corresponding to file prefix"""
-        pass
-
    # TODO: Figure out how to incorporate the Ethics `hard` test sets.
    def training_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_train.csv")
+        return self.dataset["train"]

    def validation_docs(self):
        raise NotImplementedError

    def test_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_test.csv")
+        return self.dataset["test"]

    @abc.abstractmethod
    def doc_to_text(self, doc):
@@ -109,18 +85,13 @@ class Ethics(Task):
class EthicsCM(Ethics):
    VERSION = 0
-    # Ignoring "ambiguous" extra dataset for now
-    def get_prefix(self):
-        return "commonsense/cm"
-
-    def process_doc(self, doc):
-        return doc[1:]
+    DATASET_NAME = "commonsense"  # Ignoring "ambiguous" extra dataset for now

    def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc[1])
+        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc["input"])

    def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
@@ -130,7 +101,7 @@ class EthicsCM(Ethics):
    def process_results(self, doc, results):
        ll_yes, ll_no = results
        pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold
        }
@@ -148,19 +119,14 @@ class EthicsCM(Ethics):
class EthicsDeontology(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "deontology/deontology"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "deontology"

    def doc_to_text(self, doc):
-        prompt = " ".join([doc[1], doc[2]])
+        prompt = " ".join([doc["scenario"], doc["excuse"]])
        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)

    def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
        return " {}".format(target)

    def construct_requests(self, doc, ctx):
@@ -170,14 +136,15 @@ class EthicsDeontology(Ethics):
    def process_results(self, doc, results):
        pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
        }

    def calc_em(self, items):
        # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
        preds_sort = sorted(items, key=lambda x: x[0])
        em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
        em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -198,18 +165,13 @@ class EthicsDeontology(Ethics):
class EthicsJustice(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "justice/justice"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "justice"

    def doc_to_text(self, doc):
-        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc[1])
+        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"])

    def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
        return " {}".format(target)

    def construct_requests(self, doc, ctx):
@@ -219,14 +181,15 @@ class EthicsJustice(Ethics):
    def process_results(self, doc, results):
        pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
        }

    def calc_em(self, items):
        # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
        preds_sort = sorted(items, key=lambda x: x[0])
        em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
        em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -247,17 +210,12 @@ class EthicsJustice(Ethics):
class EthicsUtilitarianismOriginal(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "utilitarianism/util"
+    DATASET_NAME = "utilitarianism"

    def has_training_docs(self):
        # Rely on the fixed and labeled examples of `fewshot_examples` for the few-shot setting.
        return False

-    def process_doc(self, docs):
-        for doc in docs:
-            yield {"activity": doc[0], "baseline": doc[1], "rating": ""}
-
    def fewshot_examples(self, k, rnd):
        # Overwriting fewshot examples as k can be max 5
        assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
@@ -311,25 +269,36 @@ class EthicsUtilitarianismOriginal(Ethics):
class EthicsUtilitarianism(Ethics):
-    VERSION = 0
    """
    This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
    This allows scaling to >5 shots.
    """
+    VERSION = 0
+    DATASET_NAME = "utilitarianism"

-    def get_prefix(self):
-        return "utilitarianism/util"
-
-    def process_doc(self, docs):
-        rnd = random.Random()
-        for doc in docs:
-            rnd.seed(doc[0])
-            ordering = [0, 1]
-            rnd.shuffle(ordering)
-            yield {
-                "scenarios": [doc[ordering[0]], doc[ordering[1]]],
-                "label": int(ordering.index(0) == 0),  # The correct scenario is always first
-            }
+    def training_docs(self):
+        rnd = random.Random()
+        for doc in self.dataset["train"]:
+            yield self._process_doc(doc, rnd)
+
+    def validation_docs(self):
+        raise NotImplementedError
+
+    def test_docs(self):
+        rnd = random.Random()
+        for doc in self.dataset["test"]:
+            yield self._process_doc(doc, rnd)
+
+    def _process_doc(self, doc, rnd):
+        rnd.seed(doc["activity"])
+        scenarios = [doc["activity"], doc["baseline"]]
+        ordering = [0, 1]
+        rnd.shuffle(ordering)
+        return {
+            "scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
+            # The correct scenario is always first
+            "label": int(ordering.index(0) == 0),
+        }

    def doc_to_text(self, doc):
        return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format(
@@ -365,23 +334,19 @@ class EthicsUtilitarianism(Ethics):
class EthicsVirtue(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "virtue/virtue"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
-
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
+    DATASET_NAME = "virtue"
+
+    def _process_doc(self, doc):
+        return doc

    def doc_to_text(self, doc):
-        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(*doc[1].split(" [SEP] "))
+        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(
+            doc["scenario"],
+            doc["trait"]
+        )

    def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
@@ -391,14 +356,15 @@ class EthicsVirtue(Ethics):
    def process_results(self, doc, results):
        ll_yes, ll_no = results
        pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
        }

    def calc_em(self, items):
        # Calculate exact matches - i.e. all in a pair of 5 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
        preds_sort = sorted(items, key=lambda x: x[0])
        em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
        em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
...
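
For the grouped ethics tasks, `calc_em` assumes every scenario group contributes a fixed number of predictions (4 for deontology and justice, 5 for virtue): sorting by `group_id` and slicing fixed-size chunks reassembles the groups, and a group scores an exact match only if all of its members are correct. A small worked example under that assumption:

# Worked example of the exact-match ("em") reduction used above, with groups of 4:
items = [(0, True), (1, True), (0, True), (1, False),
         (0, True), (1, True), (0, True), (1, True)]
preds_sort = sorted(items, key=lambda x: x[0])    # all of group 0, then all of group 1
em_sums = [sum(int(correct) for _, correct in preds_sort[4 * i:4 * i + 4])
           for i in range(len(preds_sort) // 4)]  # [4, 3]
em_cors = [s == 4 for s in em_sums]               # [True, False]
print(sum(em_cors) / len(em_cors))                # 0.5, the "em" score for these docs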
...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations. ...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
Homepage: https://github.com/hendrycks/math Homepage: https://github.com/hendrycks/math
""" """
import abc import inspect
import json import lm_eval.datasets.hendrycks_math.hendrycks_math
from lm_eval.utils import sh
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -28,21 +25,8 @@ _CITATION = """ ...@@ -28,21 +25,8 @@ _CITATION = """
class Math(Task): class Math(Task):
DATASET_PATH = Path('data/MATH') DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
DATASET_NAME = None
def download(self):
if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
sh(f"mkdir -p {self.DATASET_PATH}")
download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
sh(f"""
tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
rm {self.DATASET_PATH}.tar
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -53,28 +37,25 @@ class Math(Task): ...@@ -53,28 +37,25 @@ class Math(Task):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _load_docs(self, path):
for file in sorted(path.iterdir()):
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self): def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info()) return map(self._load_doc, self.dataset["train"])
def validation_docs(self): def validation_docs(self):
return NotImplemented return NotImplemented
def test_docs(self): def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info()) return map(self._load_doc, self.dataset["test"])
def _load_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:" return "Problem: " + doc["problem"] + "\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc["answer"] return " " + doc["solution"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"]) return rf.greedy_until(ctx, ["\n"])
...@@ -301,41 +282,34 @@ class Math(Task): ...@@ -301,41 +282,34 @@ class Math(Task):
class MathAlgebra(Math): class MathAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'algebra'
return 'algebra'
class MathCountingAndProbability(Math): class MathCountingAndProbability(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'counting_and_probability'
return 'counting_and_probability'
class MathGeometry(Math): class MathGeometry(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'geometry'
return 'geometry'
class MathIntermediateAlgebra(Math): class MathIntermediateAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'intermediate_algebra'
return 'intermediate_algebra'
class MathNumberTheory(Math): class MathNumberTheory(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'number_theory'
return 'number_theory'
class MathPrealgebra(Math): class MathPrealgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'prealgebra'
return 'prealgebra'
class MathPrecalculus(Math): class MathPrecalculus(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'precalculus'
return 'precalculus'
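The seven subclasses above now differ only in `DATASET_NAME`. A minimal sketch of how those class attributes are presumably resolved into a HuggingFace dataset; the script path below is hypothetical, and the real download logic lives in the refactored `lm_eval.base.Task`, which may differ:

```python
import datasets  # HuggingFace `datasets`

class MathAlgebraSketch:
    # Hypothetical script path; in the PR it is obtained via
    # inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math).
    DATASET_PATH = "lm_eval/datasets/hendrycks_math/hendrycks_math.py"
    DATASET_NAME = "algebra"

    def download(self):
        # The loading script handles fetching MATH and building the splits;
        # each subclass only has to pick a config name.
        self.dataset = datasets.load_dataset(self.DATASET_PATH, self.DATASET_NAME)

task = MathAlgebraSketch()
task.download()
print(task.dataset["test"][0]["problem"])
```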
...@@ -12,12 +12,7 @@ important shortcomings. ...@@ -12,12 +12,7 @@ important shortcomings.
Homepage: https://github.com/hendrycks/test Homepage: https://github.com/hendrycks/test
""" """
import csv
import random
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -61,25 +56,15 @@ def create_task(subject): ...@@ -61,25 +56,15 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask): class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = Path("data/hendrycksTest/") DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
def __init__(self, subject): def __init__(self, subject):
self.subject = subject self.DATASET_NAME = subject
super().__init__() super().__init__()
def download(self):
if not (self.DATASET_PATH / 'done').exists():
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
sh("""
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
touch data/hendrycksTest/done
""")
def has_training_docs(self): def has_training_docs(self):
return True return False
def has_validation_docs(self): def has_validation_docs(self):
return True return True
...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, doc): def validation_docs(self):
def format_example(doc, choices): return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, keys):
""" """
Question: <prompt> Question: <prompt>
Choices: Choices:
...@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask):
D. <choice4> D. <choice4>
Answer: Answer:
""" """
prompt = "Question: " + doc[0] + "\nChoices:\n" prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)]) prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "Answer:" prompt += "Answer:"
return prompt return prompt
choices = ['A', 'B', 'C', 'D'] keys = ['A', 'B', 'C', 'D']
return { return {
"query": format_example(doc, choices), "query": format_example(doc, keys),
"choices": doc[1:5], "choices": doc["choices"],
"gold": choices.index(doc[5]) "gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
} }
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is # fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't # in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename)) self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k) return rnd.sample(list(self._fewshot_docs), k)
......
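As a quick sanity check of the new `_process_doc`, the prompt it builds can be reproduced outside the class. The document below is made up, following the HuggingFace `hendrycks_test` field names used in the diff:

```python
doc = {
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Madrid", "Paris", "Rome"],
    "answer": 2,  # may also arrive as the letter "C", hence the isinstance check
}

keys = ["A", "B", "C", "D"]
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join(f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"]))
prompt += "Answer:"
gold = keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]

print(prompt)
print("gold index:", gold)  # -> 2 ("Paris")
```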
...@@ -12,12 +12,10 @@ in the broader discourse. ...@@ -12,12 +12,10 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
import json import inspect
import lm_eval.datasets.lambada.lambada
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import os
_CITATION = """ _CITATION = """
...@@ -34,19 +32,7 @@ _CITATION = """ ...@@ -34,19 +32,7 @@ _CITATION = """
class LAMBADA(Task): class LAMBADA(Task):
VERSION = 0 VERSION = 0
def download(self): DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada.lambada)
sh("mkdir -p data/lambada")
try:
if not os.path.exists("data/lambada/lambada_test.jsonl"):
download_file(
"http://eaidata.bmk.sh/data/lambada_test.jsonl",
local_file="data/lambada/lambada_test.jsonl",
expected_checksum="4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
except:
# fallback - for some reason best_download doesnt work all the time here
sh("wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl")
sh('echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226 data/lambada/lambada_test.jsonl" | sha256sum --check')
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -61,9 +47,7 @@ class LAMBADA(Task): ...@@ -61,9 +47,7 @@ class LAMBADA(Task):
pass pass
def validation_docs(self): def validation_docs(self):
with open("data/lambada/lambada_test.jsonl") as fh: return self.dataset["validation"]
for line in fh:
yield json.loads(line)
def test_docs(self): def test_docs(self):
pass pass
......
...@@ -13,12 +13,7 @@ in the broader discourse. ...@@ -13,12 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -35,6 +30,7 @@ _CITATION = """ ...@@ -35,6 +30,7 @@ _CITATION = """
class LAMBADA_cloze(LAMBADA): class LAMBADA_cloze(LAMBADA):
VERSION = 0 VERSION = 0
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->" return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
......
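The cloze variant only rewrites the prompt; a tiny standalone check of that transformation on a made-up record:

```python
doc = {"text": "He opened the door and stepped into the garden"}
print(doc["text"].rsplit(" ", 1)[0] + " ____. ->")
# He opened the door and stepped into the ____. ->
```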
...@@ -14,13 +14,6 @@ in the broader discourse. ...@@ -14,13 +14,6 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
from . import lambada from . import lambada
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import json
from functools import partial
import os
_CITATION = """ _CITATION = """
...@@ -35,68 +28,37 @@ _CITATION = """ ...@@ -35,68 +28,37 @@ _CITATION = """
""" """
LANGS = ["en", "fr", "de", "it", "es"]
CHECKSUMS = {"en": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
"fr": "941ec6a73dba7dc91c860bf493eb66a527cd430148827a4753a4535a046bf362",
"de": "51c6c1795894c46e88e4c104b5667f488efe79081fb34d746b82b8caa663865e",
"it": "86654237716702ab74f42855ae5a78455c1b0e50054a4593fb9c6fcf7fad0850",
"es": "ffd760026c647fb43c67ce1bc56fd527937304b348712dce33190ea6caba6f9c"
}
class MultilingualLAMBADA(lambada.LAMBADA): class MultilingualLAMBADA(lambada.LAMBADA):
VERSION = 0 VERSION = 0
def __init__(self, lang=None):
self.LANG = lang
super().__init__()
def download(self):
sh("mkdir -p data/lambada")
f = f"data/lambada/lambada_test_{self.LANG}.jsonl"
url = f"http://eaidata.bmk.sh/data/lambada_test_{self.LANG}.jsonl"
try:
if not os.path.exists(f):
download_file(
url,
local_file=f,
expected_checksum=CHECKSUMS[self.LANG]
)
except:
# fallback - for some reason best_download doesnt work all the time here
sh(f"wget {url} -O {f}")
sh(f'echo "{CHECKSUMS[self.LANG]} {f}" | sha256sum --check')
def validation_docs(self):
with open(f"data/lambada/lambada_test_{self.LANG}.jsonl") as fh:
for line in fh:
yield json.loads(line)
class MultilingualLAMBADAEN(MultilingualLAMBADA): class MultilingualLAMBADAEN(MultilingualLAMBADA):
def __init__(self): DATASET_NAME = 'en'
super().__init__('en')
class MultilingualLAMBADAFR(MultilingualLAMBADA): class MultilingualLAMBADAFR(MultilingualLAMBADA):
def __init__(self): DATASET_NAME = 'fr'
super().__init__('fr')
class MultilingualLAMBADADE(MultilingualLAMBADA): class MultilingualLAMBADADE(MultilingualLAMBADA):
def __init__(self): DATASET_NAME = 'de'
super().__init__('de')
class MultilingualLAMBADAIT(MultilingualLAMBADA): class MultilingualLAMBADAIT(MultilingualLAMBADA):
def __init__(self): DATASET_NAME = 'it'
super().__init__('it')
class MultilingualLAMBADAES(MultilingualLAMBADA): class MultilingualLAMBADAES(MultilingualLAMBADA):
def __init__(self): DATASET_NAME = 'es'
super().__init__('es')
LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR,
MultilingualLAMBADADE, MultilingualLAMBADAIT,
MultilingualLAMBADAES]
LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR, MultilingualLAMBADADE, MultilingualLAMBADAIT, MultilingualLAMBADAES]
def construct_tasks(): def construct_tasks():
tasks = {} tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES): for lang_class in LANG_CLASSES:
tasks[f"lambada_mt_{lang}"] = lang_class tasks[f"lambada_mt_{lang_class.DATASET_NAME}"] = lang_class
return tasks return tasks
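With the per-language `__init__` overrides gone, the task registry is keyed directly off `DATASET_NAME`, so no separate `LANGS` list has to be kept in sync. A standalone sketch with stand-in classes (the names below are placeholders, not the PR's classes):

```python
class EN:  # stand-in for MultilingualLAMBADAEN
    DATASET_NAME = "en"

class FR:  # stand-in for MultilingualLAMBADAFR
    DATASET_NAME = "fr"

def construct_tasks(lang_classes=(EN, FR)):
    return {f"lambada_mt_{cls.DATASET_NAME}": cls for cls in lang_classes}

print(sorted(construct_tasks()))  # ['lambada_mt_en', 'lambada_mt_fr']
```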
...@@ -10,9 +10,9 @@ NLP setting. ...@@ -10,9 +10,9 @@ NLP setting.
Homepage: https://github.com/lgw863/LogiQA-dataset Homepage: https://github.com/lgw863/LogiQA-dataset
""" """
import inspect
import lm_eval.datasets.logiqa.logiqa
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from best_download import download_file
from pathlib import Path
_CITATION = """ _CITATION = """
...@@ -29,21 +29,8 @@ _CITATION = """ ...@@ -29,21 +29,8 @@ _CITATION = """
class LogiQA(MultipleChoiceTask): class LogiQA(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = Path("data/logiqa") DATASET_PATH = inspect.getfile(lm_eval.datasets.logiqa.logiqa)
DATASET_NAME = None
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
base_url = "https://raw.githubusercontent.com/lgw863/LogiQA-dataset/master"
splits = [
{"name": "Train", "checksum": "7d5bb1f58278e33b395744cd2ad8d7600faa0b3c4d615c659a44ec1181d759fa"},
{"name": "Eval", "checksum": "4c49e6753b7262c001506b9151135abf722247035ab075dad93acdea5789c01f"},
{"name": "Test", "checksum": "359acb78c37802208f7fde9e2f6574b8526527c63d6a336f90a53f1932cb4701"}
]
for split in splits:
file = self.DATASET_PATH / f"{split['name']}.txt"
download_file(f"{base_url}/{split['name']}.txt", local_file=str(file), expected_checksum=split["checksum"])
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -54,7 +41,18 @@ class LogiQA(MultipleChoiceTask): ...@@ -54,7 +41,18 @@ class LogiQA(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, doc): def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, choices): def format_example(doc, choices):
""" """
Passage: <passage> Passage: <passage>
...@@ -66,7 +64,7 @@ class LogiQA(MultipleChoiceTask): ...@@ -66,7 +64,7 @@ class LogiQA(MultipleChoiceTask):
D. <choice4> D. <choice4>
Answer: Answer:
""" """
prompt = "Passage: " + doc["passage"] + "\n" prompt = "Passage: " + doc["context"] + "\n"
prompt += "Question: " + doc["question"] + "\nChoices:\n" prompt += "Question: " + doc["question"] + "\nChoices:\n"
for choice, option in zip(choices, doc["options"]): for choice, option in zip(choices, doc["options"]):
prompt += f"{choice.upper()}. {option}\n" prompt += f"{choice.upper()}. {option}\n"
...@@ -76,33 +74,8 @@ class LogiQA(MultipleChoiceTask): ...@@ -76,33 +74,8 @@ class LogiQA(MultipleChoiceTask):
return { return {
"query": format_example(doc, choices), "query": format_example(doc, choices),
"choices": doc["options"], "choices": doc["options"],
"gold": choices.index(doc["answerKey"]) "gold": choices.index(doc["label"])
} }
def _load_docs(self, filename):
def normalize(text):
return text.replace(".", ". ").strip()
with open(filename, 'r') as f:
docs = f.read().strip().split("\n\n")
for rawdoc in docs:
rawdoc = rawdoc.split("\n")
doc = {
"answerKey": rawdoc[0].strip(),
"passage": normalize(rawdoc[1]),
"question": normalize(rawdoc[2]),
"options": [normalize(option[2:]) for option in rawdoc[3:]]
}
yield self._convert_standard(doc)
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "Train.txt")
def validation_docs(self):
return self._load_docs(self.DATASET_PATH / "Eval.txt")
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "Test.txt")
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
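Finally, an illustrative run of the LogiQA `_process_doc` logic on a made-up record, assuming the elided `choices` list is `['a', 'b', 'c', 'd']` as implied by the surrounding diff:

```python
doc = {
    "context": "All birds in the park are sparrows.",
    "question": "Which statement must be true?",
    "options": [
        "Some sparrows are in the park.",
        "No sparrows are in the park.",
        "All sparrows are in the park.",
        "There are no birds in the park.",
    ],
    "label": "a",  # hypothetical record in the HF `logiqa` schema
}

choices = ["a", "b", "c", "d"]  # assumed from the elided part of _process_doc
prompt = "Passage: " + doc["context"] + "\n"
prompt += "Question: " + doc["question"] + "\nChoices:\n"
for choice, option in zip(choices, doc["options"]):
    prompt += f"{choice.upper()}. {option}\n"
prompt += "Answer:"

print(prompt)
print(choices.index(doc["label"]))  # gold -> 0
```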