"official/nlp/docs/README.md" did not exist on "0c6367eead1d19ace161c6dd5e4834329ab461e9"
Unverified commit 6caa0afd authored by Leo Gao, committed by GitHub

Merge pull request #300 from jon-tow/hf-dataset-refactor

Refactor `Task` downloading to use `HuggingFace.datasets`
parents 7064d6b9 9434722c
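The gist of the refactor: instead of each task hand-rolling a `download()` with `best_download`, tasks now declare `DATASET_PATH`/`DATASET_NAME` and read splits off `self.dataset`. A minimal sketch of the assumed pattern (the class below is illustrative, not the harness's actual base `Task`; "glue"/"cola" are just example arguments):

```python
# Minimal sketch, assuming the refactored base Task resolves DATASET_PATH /
# DATASET_NAME through `datasets.load_dataset`. Names here are illustrative.
import datasets

class SketchTask:
    DATASET_PATH = "glue"    # a Hub dataset name, or a path to a local dataset script
    DATASET_NAME = "cola"    # optional config name; None when the dataset has no configs

    def __init__(self):
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH, name=self.DATASET_NAME
        )

    def validation_docs(self):
        # Converted tasks in the diff below override hooks like this one to
        # return the appropriate split.
        return self.dataset["validation"]
```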
......@@ -10,9 +10,8 @@ provided explanations.
Homepage: "https://github.com/facebookresearch/anli"
"""
import numpy as np
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -31,7 +30,7 @@ _CITATION = """
"""
class ANLIBase(HFTask):
class ANLIBase(Task):
VERSION = 0
DATASET_PATH = "anli"
DATASET_NAME = None
......@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.data["dev_r" + str(self.SPLIT)]
return self.dataset["dev_r" + str(self.SPLIT)]
def test_docs(self):
if self.has_test_docs():
return self.data["test_r" + str(self.SPLIT)]
return self.dataset["test_r" + str(self.SPLIT)]
def doc_to_text(self, doc):
# OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
......@@ -125,11 +124,14 @@ class ANLIBase(HFTask):
"acc": True
}
class ANLIRound1(ANLIBase):
SPLIT = 1
class ANLIRound2(ANLIBase):
SPLIT = 2
class ANLIRound3(ANLIBase):
SPLIT = 3
......@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """
......@@ -27,7 +26,7 @@ _CITATION = """
"""
class ARCEasy(HFTask, MultipleChoiceTask):
class ARCEasy(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
......@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
......
......@@ -7,13 +7,10 @@ problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
import abc
import json
import os
from collections import namedtuple
import inspect
import lm_eval.datasets.arithmetic.arithmetic
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from best_download import download_file
_CITATION = """
......@@ -31,33 +28,9 @@ _CITATION = """
"""
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Task):
VERSION = 0
directory = 'data/arithmetic/'
def __init__(self):
super().__init__()
def download(self):
file_name, checksum = self.get_file_download_info()
url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
if not os.path.exists(self.directory):
os.makedirs(self.directory)
download_file(url, local_file=self.directory+file_name, expected_checksum=checksum)
self.set_docs()
@abc.abstractmethod
def get_file_download_info(self):
"""returns a tuple of (file_name, checksum)"""
pass
def set_docs(self):
file_name, _ = self.get_file_download_info()
jsons = open(self.directory+file_name, 'r')
self._docs = [self.load_doc(json.loads(line)) for line in jsons]
DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)
def has_training_docs(self):
return False
......@@ -72,25 +45,19 @@ class Arithmetic(Task):
return NotImplemented
def validation_docs(self):
return self._docs
return self.dataset["validation"]
def test_docs(self):
return NotImplemented
def doc_to_text(self, doc):
return doc.context
return doc["context"]
def doc_to_target(self, doc):
return doc.completion
return doc["completion"]
def load_doc(self, doc_json):
return ArithmeticDoc(context=doc_json['context'].strip()
.replace('\n\n', '\n')
.replace('Q:', 'Question:')
.replace('A:', 'Answer:'), completion=doc_json['completion'])
def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
return is_prediction
def process_results(self, doc, results):
......@@ -111,41 +78,40 @@ class Arithmetic(Task):
class Arithmetic2DPlus(Arithmetic):
def get_file_download_info(self):
return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
DATASET_NAME = "arithmetic_2da"
class Arithmetic2DMinus(Arithmetic):
def get_file_download_info(self):
return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
DATASET_NAME = "arithmetic_2ds"
class Arithmetic3DPlus(Arithmetic):
def get_file_download_info(self):
return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
DATASET_NAME = "arithmetic_3da"
class Arithmetic3DMinus(Arithmetic):
def get_file_download_info(self):
return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
DATASET_NAME = "arithmetic_3ds"
class Arithmetic4DPlus(Arithmetic):
def get_file_download_info(self):
return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
DATASET_NAME = "arithmetic_4da"
class Arithmetic4DMinus(Arithmetic):
def get_file_download_info(self):
return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
DATASET_NAME = "arithmetic_4ds"
class Arithmetic5DPlus(Arithmetic):
def get_file_download_info(self):
return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
DATASET_NAME = "arithmetic_5da"
class Arithmetic5DMinus(Arithmetic):
def get_file_download_info(self):
return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
DATASET_NAME = "arithmetic_5ds"
class Arithmetic2DMultiplication(Arithmetic):
def get_file_download_info(self):
return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
DATASET_NAME = "arithmetic_2dm"
class Arithmetic1DComposite(Arithmetic):
def get_file_download_info(self):
return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
DATASET_NAME = "arithmetic_1dc"
......@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
"""
from lm_eval.base import Task
from pathlib import Path
from best_download import download_file
import xml.etree.ElementTree as ET
from lm_eval.base import rf
from lm_eval.metrics import mean,perplexity
import numpy as np
from zipfile import ZipFile
import os
import inspect
import lm_eval.datasets.asdiv.asdiv
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
......@@ -39,39 +34,11 @@ _CITATION = """
class Asdiv(Task):
VERSION = 0
DATASET_PATH = Path("data/asdiv")
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
download_file(url, local_file=str(zip_path), expected_checksum=checksum)
with ZipFile(zip_path, "r") as zip:
zip.extractall(self.DATASET_PATH)
os.remove(zip_path)
def _convert_standard(self, problem):
#TODO: include solution-type and formula
out_doc = {
"question" : problem.find('Question').text,
"body" : problem.find('Body').text,
"answer": problem.find('Answer').text
}
return out_doc
def load_docs(self, textfilename, tfds=False):
tree = ET.parse(textfilename)
root = tree.getroot()
for pid, problem in enumerate(root.iter('Problem')):
out_doc = self._convert_standard(problem)
yield out_doc
DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
......@@ -81,13 +48,12 @@ class Asdiv(Task):
def training_docs(self):
raise NotImplementedError("This dataset has no training docs")
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
def validation_docs(self):
data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
return self.load_docs(data_xml_path)
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
......
......@@ -10,9 +10,8 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
from lm_eval.base import rf
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
from .common import HFTask
_CITATION = """
......@@ -32,19 +31,24 @@ _CITATION = """
"""
class BlimpTask(HFTask):
class BlimpTask(Task):
VERSION = 0
DATASET_PATH = "blimp"
def download(self):
super().download()
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def validation_docs(self):
# The HF dataset only contains a "train" dataset, but the harness expects a "validation"
# dataset. Let's use the training dataset, on the assumption that the model wasn't actually
# trained on this data.
self.data["validation"] = self.data["train"]
del self.data["train"]
return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0
......
......@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
"""
import numpy as np
from lm_eval.base import rf
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
from .common import HFTask
_CITATION = """
......@@ -30,11 +29,30 @@ _CITATION = """
"""
class CBTBase(HFTask):
class CBTBase(Task):
VERSION = 0
DATASET_PATH = "cbt"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def detokenize(self, text):
text = text.replace(" '", "'")
......
import datasets
from ..base import Task
class HFTask(Task):
DATASET_PATH = None
DATASET_NAME = None
def __init__(self):
self.data = None
super().__init__()
def download(self):
self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self):
"""Whether the task has a training set"""
return True if "train" in self.data.keys() else False
def has_validation_docs(self):
"""Whether the task has a validation set"""
return True if "validation" in self.data.keys() else False
def has_test_docs(self):
"""Whether the task has a test set"""
return True if "test" in self.data.keys() else False
def _convert_standard(self, doc):
return doc
def training_docs(self):
# Cache training for faster few-shot.
# If data is too large to fit in memory, override this method.
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.data["train"]))
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return map(self._convert_standard, self.data["validation"])
def test_docs(self):
if self.has_test_docs():
return map(self._convert_standard, self.data["test"])
def yesno(x):
if x:
return 'yes'
else:
return 'no'
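The deleted `HFTask` helper above is the old glue code; after this refactor its behavior is assumed to live in the base `Task` itself. For reference, `datasets.load_dataset` returns a `DatasetDict` keyed by split, so checks like `has_training_docs` reduce to key lookups (the dataset below is only an example):

```python
import datasets

ds = datasets.load_dataset("glue", name="cola")  # a DatasetDict keyed by split
print(list(ds.keys()))  # ['train', 'validation', 'test']
print("train" in ds)    # True -> the old has_training_docs-style check
print(ds["train"][0])   # each split is a Dataset of dict-like examples
```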
......@@ -9,13 +9,11 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
import os
import json
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
from ..utils import sh
from itertools import zip_longest
from best_download import download_file
_CITATION = """
......@@ -32,15 +30,8 @@ _CITATION = """
class CoQA(Task):
VERSION = 1
def download(self):
coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
sh ("""mkdir -p data/coqa""")
download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -52,10 +43,10 @@ class CoQA(Task):
return False
def training_docs(self):
return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
return self.dataset["train"]
def validation_docs(self):
return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
return self.dataset["validation"]
def test_docs(self):
pass
......@@ -64,9 +55,9 @@ class CoQA(Task):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer ai
question = f"Q: {q['input_text']}" + '\n\n'
answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
......@@ -74,13 +65,13 @@ class CoQA(Task):
def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
answer_forturn = doc["answers"][turn_id - 1]["input_text"]
answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
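The CoQA indexing changes above (e.g. `doc["answers"]["input_text"][turn_id - 1]`) follow from the HF dataset's column-oriented features, where each field holds parallel per-turn lists instead of a list of per-turn dicts. A tiny hedged illustration of the two layouts (values are made up):

```python
# Old raw-JSON layout: a list of per-turn dicts.
old_doc = {
    "questions": [{"input_text": "Who?"}, {"input_text": "When?"}],
    "answers":   [{"input_text": "Alice"}, {"input_text": "Tuesday"}],
}
print(old_doc["answers"][0]["input_text"])    # 'Alice'

# Assumed HF layout: one dict of parallel lists per field.
new_doc = {
    "questions": {"input_text": ["Who?", "When?"]},
    "answers":   {"input_text": ["Alice", "Tuesday"]},
}
print(new_doc["answers"]["input_text"][0])    # 'Alice'
```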
......@@ -120,8 +111,8 @@ class CoQA(Task):
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"])
raw_text = doc['answers'][turnid - 1]["input_text"]
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx):
......@@ -148,7 +139,7 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
turn_id = len(doc["questions"])
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
......
......@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
import json
import inspect
import numpy as np
import re
import string
from best_download import download_file
import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from pathlib import Path
from zipfile import ZipFile
_CITATION = """
......@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task):
VERSION = 1
DATASET_PATH = Path("data/drop")
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
zip_path = self.DATASET_PATH / "drop_dataset.zip"
download_file(url, local_file=str(zip_path), expected_checksum=checksum)
with ZipFile(zip_path, "r") as zip:
zip.extractall(self.DATASET_PATH)
DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -63,29 +51,46 @@ class DROP(Task):
def has_test_docs(self):
return False
def _load_docs(self, docs):
for doc in docs:
for qa in doc["qa_pairs"]:
yield {
"id": qa["query_id"],
"passage": doc["passage"],
"question": qa["question"],
"answers": self.get_answers(qa),
}
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
return {
"id": doc["query_id"],
"passage": doc["passage"],
"question": doc["question"],
"answers": self.get_answers(doc),
}
@classmethod
def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas = []
for i in range(len(validated_answers["number"])):
vas.append({
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
})
return vas
answers = []
answers_set = set()
candidates = [qa["answer"]] + qa.get("validated_answers", [])
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
continue
answers_set.add(answer)
answers.append(answer)
return answers
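As a runnable version of the docstring example above, flattening the column-wise `validated_answers` back into a list of per-answer dicts (all values are illustrative):

```python
validated = {
    "number": ["1", "8"],
    "date":   [{"day": "", "month": "", "year": ""}, {"day": "", "month": "", "year": ""}],
    "spans":  [["span a"], ["span b"]],
}
flat = [
    {"number": validated["number"][i],
     "date": validated["date"][i],
     "spans": validated["spans"][i]}
    for i in range(len(validated["number"]))
]
print(flat)  # [{'number': '1', ...}, {'number': '8', ...}]
```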
@classmethod
......@@ -99,14 +104,6 @@ class DROP(Task):
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
def training_docs(self):
docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
return self._load_docs([docs[k] for k in docs.keys()])
def validation_docs(self):
docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
return self._load_docs([docs[k] for k in docs.keys()])
def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
......
......@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
Homepage: https://gluebenchmark.com/
"""
import numpy as np
from lm_eval.base import rf
from ..metrics import mean, matthews_corrcoef, f1_score
from . common import HFTask, yesno
from ..utils import general_detokenize
from lm_eval.base import rf, Task
from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
# TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.
......@@ -46,7 +45,7 @@ _CITATION = """
# Single-Sentence Tasks
class CoLA(HFTask):
class CoLA(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "cola"
......@@ -60,6 +59,14 @@ class CoLA(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
......@@ -90,7 +97,7 @@ class CoLA(HFTask):
}
class SST(HFTask):
class SST(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "sst2"
......@@ -104,6 +111,14 @@ class SST(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
general_detokenize(doc["sentence"]),
......@@ -139,7 +154,7 @@ class SST(HFTask):
# Inference Tasks
class MNLI(HFTask):
class MNLI(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "mnli"
......@@ -153,13 +168,18 @@ class MNLI(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation_matched"]
return self.dataset["validation_matched"]
def test_docs(self):
if self.has_test_docs():
return self.data["test_matched"]
return self.dataset["test_matched"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
......@@ -202,14 +222,14 @@ class MNLIMismatched(MNLI):
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation_mismatched"]
return self.dataset["validation_mismatched"]
def test_docs(self):
if self.has_test_docs():
return self.data["test_mismatched"]
return self.dataset["test_mismatched"]
class QNLI(HFTask):
class QNLI(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "qnli"
......@@ -223,6 +243,14 @@ class QNLI(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["question"],
......@@ -258,7 +286,7 @@ class QNLI(HFTask):
}
class WNLI(HFTask):
class WNLI(Task):
VERSION = 1
DATASET_PATH = "glue"
DATASET_NAME = "wnli"
......@@ -272,6 +300,14 @@ class WNLI(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True or False?\nAnswer:".format(
doc["sentence1"],
......@@ -307,7 +343,7 @@ class WNLI(HFTask):
}
class RTE(HFTask):
class RTE(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "rte"
......@@ -321,6 +357,14 @@ class RTE(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True or False?\nAnswer:".format(
doc["sentence1"],
......@@ -359,7 +403,7 @@ class RTE(HFTask):
# Similarity and Paraphrase Tasks
class MRPC(HFTask):
class MRPC(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "mrpc"
......@@ -373,6 +417,14 @@ class MRPC(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
general_detokenize(doc["sentence1"]),
......@@ -409,7 +461,7 @@ class MRPC(HFTask):
}
class QQP(HFTask):
class QQP(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "qqp"
......@@ -423,6 +475,14 @@ class QQP(HFTask):
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
doc["question1"],
......@@ -459,7 +519,7 @@ class QQP(HFTask):
}
class STSB(HFTask):
class STSB(Task):
VERSION = 0
DATASET_PATH = "glue"
DATASET_NAME = "stsb"
......@@ -473,6 +533,17 @@ class STSB(HFTask):
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
doc["sentence1"],
......
......@@ -16,10 +16,9 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
import json
import inspect
import re
from best_download import download_file
import lm_eval.datasets.gsm8k.gsm8k
from pathlib import Path
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
......@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"
class GradeSchoolMath8K(Task):
VERSION = 0
DATASET_PATH = Path('data/gsm8k')
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
splits = [
{"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
{"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
]
for split in splits:
file = self.DATASET_PATH / f"{split['name']}.jsonl"
url = f"{base_url}/{split['name']}.jsonl"
download_file(url, local_file=str(file), expected_checksum=split["checksum"])
DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
def has_test_docs(self):
return True
def _load_docs(self, file):
return (json.loads(line) for line in open(file).read().splitlines())
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train.jsonl")
return self.dataset["train"]
def validation_docs(self):
raise NotImplementedError
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test.jsonl")
return self.dataset["test"]
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
......
......@@ -8,7 +8,8 @@ even for highly specialized humans.
Homepage: https://aghie.github.io/head-qa/
"""
from . common import HFTask
import inspect
import lm_eval.datasets.headqa.headqa
from lm_eval.base import MultipleChoiceTask
......@@ -24,9 +25,9 @@ _CITATION = """
"""
class HeadQABase(HFTask, MultipleChoiceTask):
class HeadQABase(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "head_qa"
DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)
def has_training_docs(self):
return True
......@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
out_doc = {
"id": doc["qid"],
"query": "Question: " + doc["qtext"] + "\nAnswer:",
......@@ -49,12 +61,15 @@ class HeadQABase(HFTask, MultipleChoiceTask):
def doc_to_text(self, doc):
return doc["query"]
class HeadQAEn(HeadQABase):
DATASET_NAME = "en"
class HeadQAEs(HeadQABase):
DATASET_NAME = "es"
# for backwards compatibility
class HeadQAEsDeprecated(HeadQABase):
DATASET_NAME = "es"
......
......@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
_CITATION = """
......@@ -28,7 +27,7 @@ _CITATION = """
"""
class HellaSwag(HFTask, MultipleChoiceTask):
class HellaSwag(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hellaswag"
DATASET_NAME = None
......@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return False
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace("  ", " ")
return text
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _convert_standard(self, doc):
def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx),
......@@ -60,5 +58,14 @@ class HellaSwag(HFTask, MultipleChoiceTask):
}
return out_doc
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace("  ", " ")
return text
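A quick demonstration of the cleanup `preprocess` performs on WikiHow-style artifacts (the sample string is invented, and the function body is reproduced here only so the snippet runs standalone):

```python
import re

def preprocess(text):
    text = text.strip()
    text = text.replace(" [title]", ". ")
    text = re.sub(r"\[.*?\]", "", text)
    text = text.replace("  ", " ")
    return text

print(preprocess("Fixing a flat tire [title] Remove the wheel. [step] Patch the tube."))
# -> 'Fixing a flat tire. Remove the wheel. Patch the tube.'
```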
def doc_to_text(self, doc):
return doc["query"]
......@@ -14,17 +14,14 @@ tasks are referred to in this work as the `em` sub-metric. See Section 3. Metrics
of the paper.
Homepage: https://github.com/hendrycks/ethics
"""
"""
import abc
import csv
import os
import random
import inspect
import lm_eval.datasets.hendrycks_ethics.hendrycks_ethics
import numpy as np
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval.utils import sh
from .common import yesno
from best_download import download_file
from lm_eval.metrics import mean, yesno
_CITATION = """
......@@ -38,15 +35,8 @@ _CITATION = """
class Ethics(Task):
def download(self):
if not os.path.exists('data/ethics/done'):
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/ethics.tar", local_file="data/ethics.tar", expected_checksum="40acbf1ac0da79a2aabef394d58889136b8d38b05be09482006de2453fb06333")
sh("""
tar -xf data/ethics.tar -C data/
rm data/ethics.tar
touch data/ethics/done
""")
DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_ethics.hendrycks_ethics)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -57,30 +47,16 @@ class Ethics(Task):
def has_test_docs(self):
return True
@abc.abstractmethod
def process_doc(self, doc):
pass
def load_doc(self, filename):
with open(filename, newline='') as file:
filereader = csv.reader(file)
return self.process_doc(list(filereader))
@abc.abstractmethod
def get_prefix(self):
"""returns string corresponding to file prefix"""
pass
# TODO: Figure out how to incorporate the Ethics `hard` test sets.
def training_docs(self):
return self.load_doc(f"data/ethics/{self.get_prefix()}_train.csv")
return self.dataset["train"]
def validation_docs(self):
raise NotImplementedError
def test_docs(self):
return self.load_doc(f"data/ethics/{self.get_prefix()}_test.csv")
return self.dataset["test"]
@abc.abstractmethod
def doc_to_text(self, doc):
......@@ -109,18 +85,13 @@ class Ethics(Task):
class EthicsCM(Ethics):
VERSION = 0
# Ignoring "ambiguous" extra dataset for now
def get_prefix(self):
return "commonsense/cm"
def process_doc(self, doc):
return doc[1:]
DATASET_NAME = "commonsense" # Ignoring "ambiguous" extra dataset for now
def doc_to_text(self, doc):
return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc[1])
return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc["input"])
def doc_to_target(self, doc):
return " {}".format(yesno(int(doc[0])))
return " {}".format(yesno(int(doc["label"])))
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
......@@ -130,7 +101,7 @@ class EthicsCM(Ethics):
def process_results(self, doc, results):
ll_yes, ll_no = results
pred = ll_yes > ll_no
gold = bool(int(doc[0]))
gold = bool(int(doc["label"]))
return {
"acc": pred == gold
}
......@@ -148,19 +119,14 @@ class EthicsCM(Ethics):
class EthicsDeontology(Ethics):
VERSION = 0
def get_prefix(self):
return "deontology/deontology"
def process_doc(self, doc):
# Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
return [x + [i] for i, x in enumerate(doc[1:])]
DATASET_NAME = "deontology"
def doc_to_text(self, doc):
prompt = " ".join([doc[1], doc[2]])
prompt = " ".join([doc["scenario"], doc["excuse"]])
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)
def doc_to_target(self, doc):
target = ["unreasonable", "reasonable"][int(doc[0])]
target = ["unreasonable", "reasonable"][int(doc["label"])]
return " {}".format(target)
def construct_requests(self, doc, ctx):
......@@ -170,14 +136,15 @@ class EthicsDeontology(Ethics):
def process_results(self, doc, results):
pred = np.argmax(results)
gold = bool(int(doc[0]))
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc[-1], pred == gold]
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
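A tiny worked example of the grouped exact match computed above: `items` holds `(group_id, is_correct)` pairs, and a group of four only counts when every prediction in it is correct (the values are made up):

```python
items = [(0, True), (0, True), (0, True), (0, True),   # group 0: all four correct
         (1, True), (1, False), (1, True), (1, True)]  # group 1: one miss
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [sum(int(preds_sort[4 * i + j][1]) for j in range(4))
           for i in range(len(preds_sort) // 4)]
em_cors = [s == 4 for s in em_sums]
print(em_cors)  # [True, False] -> exact-match accuracy of 0.5
```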
......@@ -198,18 +165,13 @@ class EthicsDeontology(Ethics):
class EthicsJustice(Ethics):
VERSION = 0
def get_prefix(self):
return "justice/justice"
def process_doc(self, doc):
# Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
return [x + [i] for i, x in enumerate(doc[1:])]
DATASET_NAME = "justice"
def doc_to_text(self, doc):
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc[1])
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"])
def doc_to_target(self, doc):
target = ["unreasonable", "reasonable"][int(doc[0])]
target = ["unreasonable", "reasonable"][int(doc["label"])]
return " {}".format(target)
def construct_requests(self, doc, ctx):
......@@ -219,14 +181,15 @@ class EthicsJustice(Ethics):
def process_results(self, doc, results):
pred = np.argmax(results)
gold = bool(int(doc[0]))
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc[-1], pred == gold]
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
......@@ -247,17 +210,12 @@ class EthicsJustice(Ethics):
class EthicsUtilitarianismOriginal(Ethics):
VERSION = 0
def get_prefix(self):
return "utilitarianism/util"
DATASET_NAME = "utilitarianism"
def has_training_docs(self):
# Rely on the fixed and labeled examples of `fewshot_examples` for the few-shot setting.
return False
def process_doc(self, docs):
for doc in docs:
yield {"activity": doc[0], "baseline": doc[1], "rating": ""}
def fewshot_examples(self, k, rnd):
# Overwriting fewshot examples as k can be max 5
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
......@@ -311,25 +269,36 @@ class EthicsUtilitarianismOriginal(Ethics):
class EthicsUtilitarianism(Ethics):
VERSION = 0
"""
This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
This allows scaling to >5 shots.
"""
VERSION = 0
DATASET_NAME = "utilitarianism"
def training_docs(self):
rnd = random.Random()
for doc in self.dataset["train"]:
yield self._process_doc(doc, rnd)
def get_prefix(self):
return "utilitarianism/util"
def validation_docs(self):
raise NotImplementedError
def process_doc(self, docs):
def test_docs(self):
rnd = random.Random()
for doc in docs:
rnd.seed(doc[0])
ordering = [0, 1]
rnd.shuffle(ordering)
yield {
"scenarios": [doc[ordering[0]], doc[ordering[1]]],
"label": int(ordering.index(0) == 0), # The correct scenario is always first
}
for doc in self.dataset["test"]:
yield self._process_doc(doc, rnd)
def _process_doc(self, doc, rnd):
rnd.seed(doc["activity"])
scenarios = [doc["activity"], doc["baseline"]]
ordering = [0, 1]
rnd.shuffle(ordering)
return {
"scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
# The correct scenario is always first
"label": int(ordering.index(0) == 0),
}
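A short walk-through of `_process_doc` above: seeding the RNG on the scenario text makes the A/B ordering deterministic per example, and `label` records whether the original (preferred) scenario ended up first. The document text below is invented:

```python
import random

doc = {"activity": "I walked my dog in the park.",
       "baseline": "I walked my dog in the rain."}

rnd = random.Random()
rnd.seed(doc["activity"])              # deterministic per example text
scenarios = [doc["activity"], doc["baseline"]]
ordering = [0, 1]
rnd.shuffle(ordering)
processed = {
    "scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
    "label": int(ordering.index(0) == 0),  # 1 iff the "activity" is shown first
}
print(processed)
```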
def doc_to_text(self, doc):
return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format(
......@@ -365,23 +334,19 @@ class EthicsUtilitarianism(Ethics):
class EthicsVirtue(Ethics):
VERSION = 0
def get_prefix(self):
return "virtue/virtue"
def process_doc(self, doc):
# Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
return [x + [i] for i, x in enumerate(doc[1:])]
DATASET_NAME = "virtue"
def load_doc(self, filename):
with open(filename, newline='') as file:
filereader = csv.reader(file)
return self.process_doc(list(filereader))
def _process_doc(self, doc):
return doc
def doc_to_text(self, doc):
return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(*doc[1].split(" [SEP] "))
return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(
doc["scenario"],
doc["trait"]
)
def doc_to_target(self, doc):
return " {}".format(yesno(int(doc[0])))
return " {}".format(yesno(int(doc["label"])))
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
......@@ -391,14 +356,15 @@ class EthicsVirtue(Ethics):
def process_results(self, doc, results):
ll_yes, ll_no = results
pred = ll_yes > ll_no
gold = bool(int(doc[0]))
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc[-1], pred == gold]
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 5 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
......
......@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
Homepage: https://github.com/hendrycks/math
"""
import abc
import json
from lm_eval.utils import sh
import inspect
import lm_eval.datasets.hendrycks_math.hendrycks_math
from lm_eval.metrics import mean
from lm_eval.base import Task, rf
from pathlib import Path
from best_download import download_file
_CITATION = """
......@@ -28,21 +25,8 @@ _CITATION = """
class Math(Task):
DATASET_PATH = Path('data/MATH')
def download(self):
if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
sh(f"mkdir -p {self.DATASET_PATH}")
download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
sh(f"""
tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
rm {self.DATASET_PATH}.tar
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -53,28 +37,25 @@ class Math(Task):
def has_test_docs(self):
return True
def _load_docs(self, path):
for file in sorted(path.iterdir()):
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info())
return map(self._load_doc, self.dataset["train"])
def validation_docs(self):
return NotImplemented
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info())
return map(self._load_doc, self.dataset["test"])
def _load_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
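The `_load_doc` hook above pulls the final boxed answer out of the LaTeX solution via the harness's `remove_boxed`/`last_boxed_only_string` helpers. A rough, hedged sketch of the same idea for simple, non-nested cases (this is not the harness's implementation):

```python
import re

def last_boxed_answer(solution: str) -> str:
    """Illustrative only: return the contents of the last flat \\boxed{...}."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", solution)
    return matches[-1] if matches else solution

print(last_boxed_answer(r"So $2 + 2 = \boxed{4}$."))  # -> '4'
```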
def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:"
def doc_to_target(self, doc):
return " " + doc["answer"]
return " " + doc["solution"]
def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"])
......@@ -301,41 +282,34 @@ class Math(Task):
class MathAlgebra(Math):
VERSION = 1
def get_file_info(self):
return 'algebra'
DATASET_NAME = 'algebra'
class MathCountingAndProbability(Math):
VERSION = 1
def get_file_info(self):
return 'counting_and_probability'
DATASET_NAME = 'counting_and_probability'
class MathGeometry(Math):
VERSION = 1
def get_file_info(self):
return 'geometry'
DATASET_NAME = 'geometry'
class MathIntermediateAlgebra(Math):
VERSION = 1
def get_file_info(self):
return 'intermediate_algebra'
DATASET_NAME = 'intermediate_algebra'
class MathNumberTheory(Math):
VERSION = 1
def get_file_info(self):
return 'number_theory'
DATASET_NAME = 'number_theory'
class MathPrealgebra(Math):
VERSION = 1
def get_file_info(self):
return 'prealgebra'
DATASET_NAME = 'prealgebra'
class MathPrecalculus(Math):
VERSION = 1
def get_file_info(self):
return 'precalculus'
DATASET_NAME = 'precalculus'
......@@ -12,12 +12,7 @@ important shortcomings.
Homepage: https://github.com/hendrycks/test
"""
import csv
import random
from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
from best_download import download_file
_CITATION = """
......@@ -61,25 +56,15 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = Path("data/hendrycksTest/")
DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
def __init__(self, subject):
self.subject = subject
self.DATASET_NAME = subject
super().__init__()
def download(self):
if not (self.DATASET_PATH / 'done').exists():
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
sh("""
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
touch data/hendrycksTest/done
""")
def has_training_docs(self):
return True
return False
def has_validation_docs(self):
return True
......@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def format_example(doc, choices):
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
......@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask):
D. <choice4>
Answer:
"""
prompt = "Question: " + doc[0] + "\nChoices:\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)])
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "Answer:"
return prompt
choices = ['A', 'B', 'C', 'D']
keys = ['A', 'B', 'C', 'D']
return {
"query": format_example(doc, choices),
"choices": doc[1:5],
"gold": choices.index(doc[5])
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
}
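For clarity, a hedged example of the prompt `_process_doc` builds, plus the `gold` branch that accepts either a letter or an integer index (the question is made up; the HF `hendrycks_test` configs are assumed to store `answer` as an int):

```python
doc = {
    "question": "Which gas makes up most of Earth's atmosphere?",
    "choices": ["Oxygen", "Nitrogen", "Carbon dioxide", "Argon"],
    "answer": 1,   # could also arrive as the letter "B" in other exports
}
keys = ["A", "B", "C", "D"]

prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join(f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"]))
prompt += "Answer:"
print(prompt)

gold = keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
print(gold)  # 1, i.e. choice "B"
```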
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename))
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k)
......
......@@ -12,12 +12,10 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
import json
import inspect
import lm_eval.datasets.lambada.lambada
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import os
_CITATION = """
......@@ -34,19 +32,7 @@ _CITATION = """
class LAMBADA(Task):
VERSION = 0
def download(self):
sh("mkdir -p data/lambada")
try:
if not os.path.exists("data/lambada/lambada_test.jsonl"):
download_file(
"http://eaidata.bmk.sh/data/lambada_test.jsonl",
local_file="data/lambada/lambada_test.jsonl",
expected_checksum="4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
except:
# fallback - for some reason best_download doesnt work all the time here
sh("wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl")
sh('echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226 data/lambada/lambada_test.jsonl" | sha256sum --check')
DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada.lambada)
def has_training_docs(self):
return False
......@@ -61,9 +47,7 @@ class LAMBADA(Task):
pass
def validation_docs(self):
with open("data/lambada/lambada_test.jsonl") as fh:
for line in fh:
yield json.loads(line)
return self.dataset["validation"]
def test_docs(self):
pass
......
......@@ -13,12 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
_CITATION = """
......@@ -35,6 +30,7 @@ _CITATION = """
class LAMBADA_cloze(LAMBADA):
VERSION = 0
def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
......
......@@ -14,13 +14,6 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from . import lambada
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import json
from functools import partial
import os
_CITATION = """
......@@ -35,68 +28,37 @@ _CITATION = """
"""
LANGS = ["en", "fr", "de", "it", "es"]
CHECKSUMS = {"en": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
"fr": "941ec6a73dba7dc91c860bf493eb66a527cd430148827a4753a4535a046bf362",
"de": "51c6c1795894c46e88e4c104b5667f488efe79081fb34d746b82b8caa663865e",
"it": "86654237716702ab74f42855ae5a78455c1b0e50054a4593fb9c6fcf7fad0850",
"es": "ffd760026c647fb43c67ce1bc56fd527937304b348712dce33190ea6caba6f9c"
}
class MultilingualLAMBADA(lambada.LAMBADA):
VERSION = 0
def __init__(self, lang=None):
self.LANG = lang
super().__init__()
def download(self):
sh("mkdir -p data/lambada")
f = f"data/lambada/lambada_test_{self.LANG}.jsonl"
url = f"http://eaidata.bmk.sh/data/lambada_test_{self.LANG}.jsonl"
try:
if not os.path.exists(f):
download_file(
url,
local_file=f,
expected_checksum=CHECKSUMS[self.LANG]
)
except:
# fallback - for some reason best_download doesnt work all the time here
sh(f"wget {url} -O {f}")
sh(f'echo "{CHECKSUMS[self.LANG]} {f}" | sha256sum --check')
def validation_docs(self):
with open(f"data/lambada/lambada_test_{self.LANG}.jsonl") as fh:
for line in fh:
yield json.loads(line)
class MultilingualLAMBADAEN(MultilingualLAMBADA):
def __init__(self):
super().__init__('en')
DATASET_NAME = 'en'
class MultilingualLAMBADAFR(MultilingualLAMBADA):
def __init__(self):
super().__init__('fr')
DATASET_NAME = 'fr'
class MultilingualLAMBADADE(MultilingualLAMBADA):
def __init__(self):
super().__init__('de')
DATASET_NAME = 'de'
class MultilingualLAMBADAIT(MultilingualLAMBADA):
def __init__(self):
super().__init__('it')
DATASET_NAME = 'it'
class MultilingualLAMBADAES(MultilingualLAMBADA):
def __init__(self):
super().__init__('es')
DATASET_NAME = 'es'
LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR,
MultilingualLAMBADADE, MultilingualLAMBADAIT,
MultilingualLAMBADAES]
LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR, MultilingualLAMBADADE, MultilingualLAMBADAIT, MultilingualLAMBADAES]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"lambada_mt_{lang}"] = lang_class
for lang_class in LANG_CLASSES:
tasks[f"lambada_mt_{lang_class.DATASET_NAME}"] = lang_class
return tasks
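The loop above now keys each multilingual task by its `DATASET_NAME`. A minimal stand-alone sketch of the same pattern with stand-in classes:

```python
class LambadaEN:
    DATASET_NAME = "en"

class LambadaFR:
    DATASET_NAME = "fr"

LANG_CLASSES = [LambadaEN, LambadaFR]
tasks = {f"lambada_mt_{cls.DATASET_NAME}": cls for cls in LANG_CLASSES}
print(list(tasks))  # ['lambada_mt_en', 'lambada_mt_fr']
```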
......@@ -10,9 +10,9 @@ NLP setting.
Homepage: https://github.com/lgw863/LogiQA-dataset
"""
import inspect
import lm_eval.datasets.logiqa.logiqa
from lm_eval.base import MultipleChoiceTask
from best_download import download_file
from pathlib import Path
_CITATION = """
......@@ -29,21 +29,8 @@ _CITATION = """
class LogiQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = Path("data/logiqa")
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
base_url = "https://raw.githubusercontent.com/lgw863/LogiQA-dataset/master"
splits = [
{"name": "Train", "checksum": "7d5bb1f58278e33b395744cd2ad8d7600faa0b3c4d615c659a44ec1181d759fa"},
{"name": "Eval", "checksum": "4c49e6753b7262c001506b9151135abf722247035ab075dad93acdea5789c01f"},
{"name": "Test", "checksum": "359acb78c37802208f7fde9e2f6574b8526527c63d6a336f90a53f1932cb4701"}
]
for split in splits:
file = self.DATASET_PATH / f"{split['name']}.txt"
download_file(f"{base_url}/{split['name']}.txt", local_file=str(file), expected_checksum=split["checksum"])
DATASET_PATH = inspect.getfile(lm_eval.datasets.logiqa.logiqa)
DATASET_NAME = None
def has_training_docs(self):
return True
......@@ -54,7 +41,18 @@ class LogiQA(MultipleChoiceTask):
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, choices):
"""
Passage: <passage>
......@@ -66,7 +64,7 @@ class LogiQA(MultipleChoiceTask):
D. <choice4>
Answer:
"""
prompt = "Passage: " + doc["passage"] + "\n"
prompt = "Passage: " + doc["context"] + "\n"
prompt += "Question: " + doc["question"] + "\nChoices:\n"
for choice, option in zip(choices, doc["options"]):
prompt += f"{choice.upper()}. {option}\n"
......@@ -76,33 +74,8 @@ class LogiQA(MultipleChoiceTask):
return {
"query": format_example(doc, choices),
"choices": doc["options"],
"gold": choices.index(doc["answerKey"])
"gold": choices.index(doc["label"])
}
def _load_docs(self, filename):
def normalize(text):
return text.replace(".", ". ").strip()
with open(filename, 'r') as f:
docs = f.read().strip().split("\n\n")
for rawdoc in docs:
rawdoc = rawdoc.split("\n")
doc = {
"answerKey": rawdoc[0].strip(),
"passage": normalize(rawdoc[1]),
"question": normalize(rawdoc[2]),
"options": [normalize(option[2:]) for option in rawdoc[3:]]
}
yield self._convert_standard(doc)
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "Train.txt")
def validation_docs(self):
return self._load_docs(self.DATASET_PATH / "Eval.txt")
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "Test.txt")
def doc_to_text(self, doc):
return doc["query"]