Unverified Commit 4887d9d3 authored by Jonathan Tow, committed by GitHub

Merge pull request #307 from Taekyoon/multilingual-ko-korsts

Add klue-sts task to eval Korean language task
parents 7064d6b9 4a1041c1
@@ -50,6 +50,8 @@ from . import truthfulqa
from . import blimp
from . import asdiv
from . import gsm8k
+from . import storycloze
+from . import klue
########################################
# Translation tasks
@@ -136,7 +138,6 @@ TASK_REGISTRY = {
    "logiqa": logiqa.LogiQA,
    "hellaswag": hellaswag.HellaSwag,
    "openbookqa": openbookqa.OpenBookQA,
-    # "sat": sat.SATAnalogies, # not implemented yet
    "squad2": squad.SQuAD2,
    "race": race.RACE,
    # "naturalqs": naturalqs.NaturalQs, # not implemented yet
@@ -297,6 +298,14 @@ TASK_REGISTRY = {
    "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
    "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
    "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
+    # Requires manual download of data.
+    # "storycloze_2016": storycloze.StoryCloze2016,
+    # "storycloze_2018": storycloze.StoryCloze2018,
+    # "sat": sat.SATAnalogies,
+    # KLUE
+    "klue_sts": klue.STS
}
...
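For context (not part of the diff): once "klue_sts" is registered above it can be resolved by name through the harness task registry. A minimal usage sketch, assuming the existing lm_eval.tasks.get_task_dict helper and the standard Task doc API behave as they do for the other registry entries:

from lm_eval import tasks

# Look the new Korean STS task up by its registry key (sketch only).
task_dict = tasks.get_task_dict(["klue_sts"])
klue_sts = task_dict["klue_sts"]

# Print a few rendered validation prompts.
for i, doc in enumerate(klue_sts.validation_docs()):
    print(klue_sts.doc_to_text(doc))
    if i >= 2:
        break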
@@ -10,9 +10,8 @@ provided explanations.
Homepage: "https://github.com/facebookresearch/anli"
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
-from ..metrics import mean
+from lm_eval.metrics import mean
-from . common import HFTask
_CITATION = """
@@ -31,7 +30,7 @@ _CITATION = """
"""
-class ANLIBase(HFTask):
+class ANLIBase(Task):
    VERSION = 0
    DATASET_PATH = "anli"
    DATASET_NAME = None
@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
-                self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
+                self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
            return self._training_docs
    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["dev_r" + str(self.SPLIT)]
+            return self.dataset["dev_r" + str(self.SPLIT)]
    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_r" + str(self.SPLIT)]
+            return self.dataset["test_r" + str(self.SPLIT)]
    def doc_to_text(self, doc):
        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
@@ -125,11 +124,14 @@ class ANLIBase(HFTask):
            "acc": True
        }
class ANLIRound1(ANLIBase):
    SPLIT = 1
class ANLIRound2(ANLIBase):
    SPLIT = 2
class ANLIRound3(ANLIBase):
    SPLIT = 3
@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
_CITATION = """
@@ -27,7 +26,7 @@ _CITATION = """
"""
-class ARCEasy(HFTask, MultipleChoiceTask):
+class ARCEasy(MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "ai2_arc"
    DATASET_NAME = "ARC-Easy"
@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return True
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+    def _process_doc(self, doc):
        # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
        # of {'1', '2', '3', '4', '5'}. We map them back to letters.
        num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
...
@@ -7,13 +7,10 @@ problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
-import abc
-import json
-import os
-from collections import namedtuple
+import inspect
+import lm_eval.datasets.arithmetic.arithmetic
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
-from best_download import download_file
_CITATION = """
@@ -31,33 +28,9 @@ _CITATION = """
"""
-ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Task):
    VERSION = 0
-    directory = 'data/arithmetic/'
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)
-    def __init__(self):
-        super().__init__()
-    def download(self):
-        file_name, checksum = self.get_file_download_info()
-        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
-        if not os.path.exists(self.directory):
-            os.makedirs(self.directory)
-        download_file(url, local_file=self.directory+file_name, expected_checksum=checksum)
-        self.set_docs()
-    @abc.abstractmethod
-    def get_file_download_info(self):
-        """returns a tuple of (file_name, checksum)"""
-        pass
-    def set_docs(self):
-        file_name, _ = self.get_file_download_info()
-        jsons = open(self.directory+file_name, 'r')
-        self._docs = [self.load_doc(json.loads(line)) for line in jsons]
    def has_training_docs(self):
        return False
@@ -72,25 +45,19 @@ class Arithmetic(Task):
        return NotImplemented
    def validation_docs(self):
-        return self._docs
+        return self.dataset["validation"]
    def test_docs(self):
        return NotImplemented
    def doc_to_text(self, doc):
-        return doc.context
+        return doc["context"]
    def doc_to_target(self, doc):
-        return doc.completion
+        return doc["completion"]
-    def load_doc(self, doc_json):
-        return ArithmeticDoc(context=doc_json['context'].strip()
-                             .replace('\n\n', '\n')
-                             .replace('Q:', 'Question:')
-                             .replace('A:', 'Answer:'), completion=doc_json['completion'])
    def construct_requests(self, doc, ctx):
-        ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
+        ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
        return is_prediction
    def process_results(self, doc, results):
@@ -111,41 +78,40 @@ class Arithmetic(Task):
class Arithmetic2DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
+    DATASET_NAME = "arithmetic_2da"
class Arithmetic2DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
+    DATASET_NAME = "arithmetic_2ds"
class Arithmetic3DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
+    DATASET_NAME = "arithmetic_3da"
class Arithmetic3DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
+    DATASET_NAME = "arithmetic_3ds"
class Arithmetic4DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
+    DATASET_NAME = "arithmetic_4da"
class Arithmetic4DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
+    DATASET_NAME = "arithmetic_4ds"
class Arithmetic5DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
+    DATASET_NAME = "arithmetic_5da"
class Arithmetic5DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
+    DATASET_NAME = "arithmetic_5ds"
class Arithmetic2DMultiplication(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
+    DATASET_NAME = "arithmetic_2dm"
class Arithmetic1DComposite(Arithmetic):
-    def get_file_download_info(self):
-        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
+    DATASET_NAME = "arithmetic_1dc"
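For context (not part of the diff): DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic) points the HF datasets loader at a dataset-builder script shipped inside the harness instead of downloading the raw JSONL files. A rough sketch of what the refactored Task base is assumed to do with DATASET_PATH and DATASET_NAME:

import inspect
import datasets
import lm_eval.datasets.arithmetic.arithmetic

# Path of the local dataset script, e.g. .../lm_eval/datasets/arithmetic/arithmetic.py
# (exact path is illustrative).
script_path = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)

# Assumption: the Task base calls datasets.load_dataset() roughly like this.
data = datasets.load_dataset(path=script_path, name="arithmetic_2da")
print(data["validation"][0]["context"])  # fields match doc_to_text()/doc_to_target() above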
@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
"""
-from lm_eval.base import Task
-from pathlib import Path
-from best_download import download_file
-import xml.etree.ElementTree as ET
-from lm_eval.base import rf
-from lm_eval.metrics import mean,perplexity
-import numpy as np
-from zipfile import ZipFile
-import os
+import inspect
+import lm_eval.datasets.asdiv.asdiv
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean
_CITATION = """
@@ -39,39 +34,11 @@ _CITATION = """
class Asdiv(Task):
    VERSION = 0
-    DATASET_PATH = Path("data/asdiv")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
-        checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
-        zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
-        os.remove(zip_path)
-    def _convert_standard(self, problem):
-        #TODO: include solution-type and formula
-        out_doc = {
-            "question" : problem.find('Question').text,
-            "body" : problem.find('Body').text,
-            "answer": problem.find('Answer').text
-        }
-        return out_doc
-    def load_docs(self, textfilename, tfds=False):
-        tree = ET.parse(textfilename)
-        root = tree.getroot()
-        for pid, problem in enumerate(root.iter('Problem')):
-            out_doc = self._convert_standard(problem)
-            yield out_doc
    def has_training_docs(self):
        return False
    def has_validation_docs(self):
        return True
@@ -81,13 +48,12 @@ class Asdiv(Task):
    def training_docs(self):
        raise NotImplementedError("This dataset has no training docs")
+    def validation_docs(self):
+        return self.dataset["validation"]
    def test_docs(self):
        raise NotImplementedError("This dataset has no test docs")
-    def validation_docs(self):
-        data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
-        return self.load_docs(data_xml_path)
    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
        return super().fewshot_context(
...
@@ -10,9 +10,8 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask
_CITATION = """
@@ -32,19 +31,24 @@ _CITATION = """
"""
-class BlimpTask(HFTask):
+class BlimpTask(Task):
    VERSION = 0
    DATASET_PATH = "blimp"
-    def download(self):
-        super().download()
+    def has_training_docs(self):
+        return False
+    def has_validation_docs(self):
+        return True
+    def has_test_docs(self):
+        return False
+    def validation_docs(self):
        # The HF dataset only contains a "train" dataset, but the harness expects a "validation"
        # dataset. Let's use the training dataset, on the assumption that the model wasn't actually
        # trained on this data.
+        return self.dataset["train"]
-        self.data["validation"] = self.data["train"]
-        del self.data["train"]
    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0
...
@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask
_CITATION = """
@@ -30,11 +29,30 @@ _CITATION = """
"""
-class CBTBase(HFTask):
+class CBTBase(Task):
    VERSION = 0
    DATASET_PATH = "cbt"
    DATASET_NAME = None
+    def has_training_docs(self):
+        return True
+    def has_validation_docs(self):
+        return True
+    def has_test_docs(self):
+        return True
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
+    def test_docs(self):
+        return self.dataset["test"]
    def detokenize(self, text):
        text = text.replace(" '", "'")
...
-import datasets
-from ..base import Task
-class HFTask(Task):
-    DATASET_PATH = None
-    DATASET_NAME = None
-    def __init__(self):
-        self.data = None
-        super().__init__()
-    def download(self):
-        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
-    def has_training_docs(self):
-        """Whether the task has a training set"""
-        return True if "train" in self.data.keys() else False
-    def has_validation_docs(self):
-        """Whether the task has a validation set"""
-        return True if "validation" in self.data.keys() else False
-    def has_test_docs(self):
-        """Whether the task has a test set"""
-        return True if "test" in self.data.keys() else False
-    def _convert_standard(self, doc):
-        return doc
-    def training_docs(self):
-        # Cache training for faster few-shot.
-        # If data is too large to fit in memory, override this method.
-        if self.has_training_docs():
-            if self._training_docs is None:
-                self._training_docs = list(map(self._convert_standard, self.data["train"]))
-            return self._training_docs
-    def validation_docs(self):
-        if self.has_validation_docs():
-            return map(self._convert_standard, self.data["validation"])
-    def test_docs(self):
-        if self.has_test_docs():
-            return map(self._convert_standard, self.data["test"])
-def yesno(x):
-    if x:
-        return 'yes'
-    else:
-        return 'no'
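For context (not part of the diff): the deleted HFTask helper above is what the converted tasks used to inherit; it only wrapped datasets.load_dataset and exposed the splits. A minimal before/after sketch (illustrative class names, not from the repo) of the pattern this refactor replaces:

# Before: subclasses set DATASET_PATH/DATASET_NAME and a per-doc conversion hook.
class MyOldTask(HFTask):
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"

    def _convert_standard(self, doc):
        return doc

# After: subclasses of the plain Task base declare the splits themselves and read
# from self.dataset, which is assumed to be populated from DATASET_PATH/DATASET_NAME.
class MyNewTask(Task):
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"

    def has_training_docs(self):
        return True

    def training_docs(self):
        if self._training_docs is None:
            # Cache the split as a list so few-shot sampling does not re-iterate it.
            self._training_docs = list(self.dataset["train"])
        return self._training_docs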
@@ -9,13 +9,11 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
-import os
-import json
+import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
+import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
-from ..utils import sh
from itertools import zip_longest
-from best_download import download_file
_CITATION = """
@@ -32,15 +30,8 @@ _CITATION = """
class CoQA(Task):
    VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
-    def download(self):
+    DATASET_NAME = None
-        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
-        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
-        sh ("""mkdir -p data/coqa""")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
    def has_training_docs(self):
        return True
@@ -52,10 +43,10 @@ class CoQA(Task):
        return False
    def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        return self.dataset["train"]
    def validation_docs(self):
-        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
+        return self.dataset["validation"]
    def test_docs(self):
        pass
@@ -64,9 +55,9 @@ class CoQA(Task):
        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
        # and a question qi, the task is to predict the answer ai
        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer ai
+        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
-            question = f"Q: {q['input_text']}" + '\n\n'
+            question = f"Q: {q}\n\n"
-            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+            answer = f"A: {a}\n\n" if a is not None else "A:"
            doc_text += question + answer
        return doc_text
@@ -74,13 +65,13 @@
    def get_answers(cls, doc, turn_id):
        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
        answers = []
-        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
        answers.append(answer_forturn)
        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers
@@ -120,8 +111,8 @@ class CoQA(Task):
    def doc_to_target(self, doc, turnid=None):
        # Default to prediction of last turn.
        if turnid is None:
-            turnid = len(doc["questions"])
+            turnid = len(doc["questions"]["input_text"])
-        raw_text = doc['answers'][turnid - 1]["input_text"]
+        raw_text = doc['answers']["input_text"][turnid - 1]
        return " " + raw_text
    def construct_requests(self, doc, ctx):
@@ -148,7 +139,7 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
-        turn_id = len(doc["questions"])
+        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        pred = results[0].strip().split('\n')[0]
...
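For context (not part of the diff): the CoQA changes reflect a switch from a list of per-turn answer dicts (doc["answers"][i]["input_text"]) to the columnar layout of the HF dataset (doc["answers"]["input_text"][i]). A small illustration with made-up data of how the rewritten doc_to_text pairs each question with every answer except the last:

from itertools import zip_longest

# Toy document in the columnar layout assumed above (values invented).
doc = {
    "story": "Ana has a red kite. She flies it in the park.",
    "questions": {"input_text": ["What does Ana have?", "What color is it?"]},
    "answers": {"input_text": ["a kite", "red"]},
}

text = doc["story"] + "\n\n"
# zip_longest leaves the final question unanswered so the model must predict it.
for q, a in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):
    text += f"Q: {q}\n\n" + (f"A: {a}\n\n" if a is not None else "A:")
print(text)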
@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
-import json
+import inspect
import numpy as np
import re
import string
-from best_download import download_file
+import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
-from pathlib import Path
-from zipfile import ZipFile
_CITATION = """
@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task):
    VERSION = 1
-    DATASET_PATH = Path("data/drop")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_NAME = None
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
-        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
-        zip_path = self.DATASET_PATH / "drop_dataset.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
    def has_training_docs(self):
        return True
@@ -63,29 +51,46 @@ class DROP(Task):
    def has_test_docs(self):
        return False
-    def _load_docs(self, docs):
-        for doc in docs:
-            for qa in doc["qa_pairs"]:
-                yield {
-                    "id": qa["query_id"],
-                    "passage": doc["passage"],
-                    "question": qa["question"],
-                    "answers": self.get_answers(qa),
-                }
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def _process_doc(self, doc):
+        return {
+            "id": doc["query_id"],
+            "passage": doc["passage"],
+            "question": doc["question"],
+            "answers": self.get_answers(doc),
+        }
    @classmethod
    def get_answers(cls, qa):
+        def _flatten_validated_answers(validated_answers):
+            """ Flattens a dict of lists of validated answers.
+            {"number": ['1', '8'], ...}
+            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
+            """
+            vas = []
+            for i in range(len(validated_answers["number"])):
+                vas.append({
+                    "number": validated_answers["number"][i],
+                    "date": validated_answers["date"][i],
+                    "spans": validated_answers["spans"][i],
+                })
+            return vas
        answers = []
        answers_set = set()
-        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
                continue
            answers_set.add(answer)
            answers.append(answer)
        return answers
    @classmethod
@@ -99,14 +104,6 @@ class DROP(Task):
                answer["date"]["month"],
                answer["date"]["year"]]).strip(),)
-    def training_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-    def validation_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
    def doc_to_text(self, doc):
        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
...
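For context (not part of the diff): _flatten_validated_answers turns the columnar validated_answers field of the HF drop dataset back into one candidate dict per annotation, so parse_answer can treat each one like qa["answer"]. A tiny worked example with invented values:

validated = {
    "number": ["1", ""],
    "date": [{"day": "", "month": "", "year": ""}, {"day": "", "month": "", "year": ""}],
    "spans": [[], ["two touchdowns"]],
}

# Re-group the parallel columns into per-annotation dicts, mirroring the helper above.
flattened = [
    {"number": validated["number"][i],
     "date": validated["date"][i],
     "spans": validated["spans"][i]}
    for i in range(len(validated["number"]))
]
print(flattened)
# [{'number': '1', ...}, {'number': '', ..., 'spans': ['two touchdowns']}]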
@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
Homepage: https://gluebenchmark.com/
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
-from ..metrics import mean, matthews_corrcoef, f1_score
+from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
-from . common import HFTask, yesno
+from lm_eval.utils import general_detokenize
-from ..utils import general_detokenize
# TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.
@@ -46,7 +45,7 @@ _CITATION = """
# Single-Sentence Tasks
-class CoLA(HFTask):
+class CoLA(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "cola"
@@ -60,6 +59,14 @@ class CoLA(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
@@ -90,7 +97,7 @@ class CoLA(HFTask):
        }
-class SST(HFTask):
+class SST(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "sst2"
@@ -104,6 +111,14 @@ class SST(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
            general_detokenize(doc["sentence"]),
@@ -139,7 +154,7 @@ class SST(HFTask):
# Inference Tasks
-class MNLI(HFTask):
+class MNLI(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "mnli"
@@ -153,13 +168,18 @@ class MNLI(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["validation_matched"]
+            return self.dataset["validation_matched"]
    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_matched"]
+            return self.dataset["test_matched"]
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
@@ -202,14 +222,14 @@ class MNLIMismatched(MNLI):
    def validation_docs(self):
        if self.has_validation_docs():
-            return self.data["validation_mismatched"]
+            return self.dataset["validation_mismatched"]
    def test_docs(self):
        if self.has_test_docs():
-            return self.data["test_mismatched"]
+            return self.dataset["test_mismatched"]
-class QNLI(HFTask):
+class QNLI(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "qnli"
@@ -223,6 +243,14 @@ class QNLI(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
            doc["question"],
@@ -258,7 +286,7 @@ class QNLI(HFTask):
        }
-class WNLI(HFTask):
+class WNLI(Task):
    VERSION = 1
    DATASET_PATH = "glue"
    DATASET_NAME = "wnli"
@@ -272,6 +300,14 @@ class WNLI(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True or False?\nAnswer:".format(
            doc["sentence1"],
@@ -307,7 +343,7 @@ class WNLI(HFTask):
        }
-class RTE(HFTask):
+class RTE(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "rte"
@@ -321,6 +357,14 @@ class RTE(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True or False?\nAnswer:".format(
            doc["sentence1"],
@@ -359,7 +403,7 @@ class RTE(HFTask):
# Similarity and Paraphrase Tasks
-class MRPC(HFTask):
+class MRPC(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "mrpc"
@@ -373,6 +417,14 @@ class MRPC(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
            general_detokenize(doc["sentence1"]),
@@ -409,7 +461,7 @@ class MRPC(HFTask):
        }
-class QQP(HFTask):
+class QQP(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "qqp"
@@ -423,6 +475,14 @@ class QQP(HFTask):
    def has_test_docs(self):
        return False
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
    def doc_to_text(self, doc):
        return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
            doc["question1"],
@@ -459,7 +519,7 @@ class QQP(HFTask):
        }
-class STSB(HFTask):
+class STSB(Task):
    VERSION = 0
    DATASET_PATH = "glue"
    DATASET_NAME = "stsb"
@@ -473,6 +533,17 @@ class STSB(HFTask):
    def has_test_docs(self):
        return True
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+    def validation_docs(self):
+        return self.dataset["validation"]
+    def test_docs(self):
+        return self.dataset["test"]
    def doc_to_text(self, doc):
        return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
            doc["sentence1"],
...
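For context (not part of the diff): every GLUE task above repeats the same training_docs caching idiom; list(self.dataset["train"]) materializes the HF split once because few-shot sampling indexes into the training docs repeatedly. A hedged sketch of how that shared idiom could be factored out (hypothetical mixin, not in this commit):

class CachedTrainingDocsMixin:
    """Hypothetical helper: cache the HF train split the first time it is requested."""

    def training_docs(self):
        if self._training_docs is None:
            # Materializing the split avoids re-iterating a datasets.Dataset
            # every time few-shot examples are sampled.
            self._training_docs = list(self.dataset["train"])
        return self._training_docs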
@@ -16,10 +16,9 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
+import inspect
-import json
import re
-from best_download import download_file
+import lm_eval.datasets.gsm8k.gsm8k
from pathlib import Path
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"
class GradeSchoolMath8K(Task):
    VERSION = 0
-    DATASET_PATH = Path('data/gsm8k')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
+    DATASET_NAME = None
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
-        splits = [
-            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
-            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.jsonl"
-            url = f"{base_url}/{split['name']}.jsonl"
-            download_file(url, local_file=str(file), expected_checksum=split["checksum"])
    def has_training_docs(self):
        return True
@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
    def has_test_docs(self):
        return True
-    def _load_docs(self, file):
-        return (json.loads(line) for line in open(file).read().splitlines())
    def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train.jsonl")
+        return self.dataset["train"]
    def validation_docs(self):
        raise NotImplementedError
    def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test.jsonl")
+        return self.dataset["test"]
    def doc_to_text(self, doc):
        return "Question: " + doc['question'] + '\nAnswer:'
...
@@ -8,7 +8,8 @@ even for highly specialized humans.
Homepage: https://aghie.github.io/head-qa/
"""
-from . common import HFTask
+import inspect
+import lm_eval.datasets.headqa.headqa
from lm_eval.base import MultipleChoiceTask
@@ -24,9 +25,9 @@ _CITATION = """
"""
-class HeadQABase(HFTask, MultipleChoiceTask):
+class HeadQABase(MultipleChoiceTask):
    VERSION = 0
-    DATASET_PATH = "head_qa"
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)
    def has_training_docs(self):
        return True
@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return True
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+    def _process_doc(self, doc):
        out_doc = {
            "id": doc["qid"],
            "query": "Question: " + doc["qtext"] + "\nAnswer:",
@@ -49,12 +61,15 @@ class HeadQABase(HFTask, MultipleChoiceTask):
    def doc_to_text(self, doc):
        return doc["query"]
class HeadQAEn(HeadQABase):
    DATASET_NAME = "en"
class HeadQAEs(HeadQABase):
    DATASET_NAME = "es"
# for backwards compatibility
class HeadQAEsDeprecated(HeadQABase):
    DATASET_NAME = "es"
...
@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
_CITATION = """
@@ -28,7 +27,7 @@ _CITATION = """
"""
-class HellaSwag(HFTask, MultipleChoiceTask):
+class HellaSwag(MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None
@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
    def has_test_docs(self):
        return False
-    @classmethod
-    def preprocess(cls, text):
-        text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub('\\[.*?\\]', '', text)
-        text = text.replace("  ", " ")
-        return text
-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+    def _process_doc(self, doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
@@ -60,5 +58,14 @@ class HellaSwag(HFTask, MultipleChoiceTask):
        }
        return out_doc
+    @classmethod
+    def preprocess(cls, text):
+        text = text.strip()
+        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+        text = text.replace(" [title]", ". ")
+        text = re.sub('\\[.*?\\]', '', text)
+        text = text.replace("  ", " ")
+        return text
    def doc_to_text(self, doc):
        return doc["query"]
@@ -14,17 +14,14 @@ tasks are refered to in this work as the `em` sub-metric. See Section 3. Metrics
of the paper.
Homepage: https://github.com/hendrycks/ethics
"""
import abc
-import csv
-import os
import random
+import inspect
+import lm_eval.datasets.hendrycks_ethics.hendrycks_ethics
import numpy as np
from lm_eval.base import Task, rf
-from lm_eval.metrics import mean
+from lm_eval.metrics import mean, yesno
-from lm_eval.utils import sh
-from .common import yesno
-from best_download import download_file
_CITATION = """
@@ -38,15 +35,8 @@ _CITATION = """
class Ethics(Task):
-    def download(self):
-        if not os.path.exists('data/ethics/done'):
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/ethics.tar", local_file="data/ethics.tar", expected_checksum="40acbf1ac0da79a2aabef394d58889136b8d38b05be09482006de2453fb06333")
-            sh("""
-                tar -xf data/ethics.tar -C data/
-                rm data/ethics.tar
-                touch data/ethics/done
-            """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_ethics.hendrycks_ethics)
+    DATASET_NAME = None
    def has_training_docs(self):
        return True
@@ -57,30 +47,16 @@ class Ethics(Task):
    def has_test_docs(self):
        return True
-    @abc.abstractmethod
-    def process_doc(self, doc):
-        pass
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
-    @abc.abstractmethod
-    def get_prefix(self):
-        """returns string corresponding to file prefix"""
-        pass
    # TODO: Figure out how to incorporate the Ethics `hard` test sets.
    def training_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_train.csv")
+        return self.dataset["train"]
    def validation_docs(self):
        raise NotImplementedError
    def test_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_test.csv")
+        return self.dataset["test"]
    @abc.abstractmethod
    def doc_to_text(self, doc):
@@ -109,18 +85,13 @@ class Ethics(Task):
class EthicsCM(Ethics):
    VERSION = 0
-    # Ignoring "ambiguous" extra dataset for now
+    DATASET_NAME = "commonsense"  # Ignoring "ambiguous" extra dataset for now
-    def get_prefix(self):
-        return "commonsense/cm"
-    def process_doc(self, doc):
-        return doc[1:]
    def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc[1])
+        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc["input"])
    def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))
    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
@@ -130,7 +101,7 @@ class EthicsCM(Ethics):
    def process_results(self, doc, results):
        ll_yes, ll_no = results
        pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold
        }
@@ -148,19 +119,14 @@ class EthicsCM(Ethics):
class EthicsDeontology(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "deontology/deontology"
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "deontology"
    def doc_to_text(self, doc):
-        prompt = " ".join([doc[1], doc[2]])
+        prompt = " ".join([doc["scenario"], doc["excuse"]])
        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)
    def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
        return " {}".format(target)
    def construct_requests(self, doc, ctx):
@@ -170,14 +136,15 @@ class EthicsDeontology(Ethics):
    def process_results(self, doc, results):
        pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
        }
    def calc_em(self, items):
        # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
        preds_sort = sorted(items, key=lambda x: x[0])
        em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
        em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -198,18 +165,13 @@ class EthicsDeontology(Ethics):
class EthicsJustice(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "justice/justice"
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "justice"
    def doc_to_text(self, doc):
-        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc[1])
+        return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"])
    def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
        return " {}".format(target)
    def construct_requests(self, doc, ctx):
@@ -219,14 +181,15 @@ class EthicsJustice(Ethics):
    def process_results(self, doc, results):
        pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
        return {
            "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
        }
    def calc_em(self, items):
        # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
        preds_sort = sorted(items, key=lambda x: x[0])
        em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
        em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
@@ -247,17 +210,12 @@ class EthicsJustice(Ethics):
class EthicsUtilitarianismOriginal(Ethics):
    VERSION = 0
-    def get_prefix(self):
-        return "utilitarianism/util"
+    DATASET_NAME = "utilitarianism"
    def has_training_docs(self):
        # Rely on the fixed and labeled examples of `fewshot_examples` for the few-shot setting.
        return False
-    def process_doc(self, docs):
-        for doc in docs:
-            yield {"activity": doc[0], "baseline": doc[1], "rating": ""}
    def fewshot_examples(self, k, rnd):
        # Overwriting fewshot examples as k can be max 5
        assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
@@ -311,25 +269,36 @@ class EthicsUtilitarianismOriginal(Ethics):
class EthicsUtilitarianism(Ethics): class EthicsUtilitarianism(Ethics):
VERSION = 0
""" """
This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared. This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
This allows scaling to >5 shots. This allows scaling to >5 shots.
""" """
VERSION = 0
DATASET_NAME = "utilitarianism"
def training_docs(self):
rnd = random.Random()
for doc in self.dataset["train"]:
yield self._process_doc(doc, rnd)
def get_prefix(self): def validation_docs(self):
return "utilitarianism/util" raise NotImplementedError
def process_doc(self, docs): def test_docs(self):
rnd = random.Random() rnd = random.Random()
for doc in docs: for doc in self.dataset["test"]:
rnd.seed(doc[0]) yield self._process_doc(doc, rnd)
ordering = [0, 1]
rnd.shuffle(ordering) def _process_doc(self, doc, rnd):
yield { rnd.seed(doc["activity"])
"scenarios": [doc[ordering[0]], doc[ordering[1]]], scenarios = [doc["activity"], doc["baseline"]]
"label": int(ordering.index(0) == 0), # The correct scenario is always first ordering = [0, 1]
} rnd.shuffle(ordering)
return {
"scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
# The correct scenario is always first
"label": int(ordering.index(0) == 0),
}
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format( return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format(
...@@ -365,23 +334,19 @@ class EthicsUtilitarianism(Ethics): ...@@ -365,23 +334,19 @@ class EthicsUtilitarianism(Ethics):
class EthicsVirtue(Ethics): class EthicsVirtue(Ethics):
VERSION = 0 VERSION = 0
def get_prefix(self): DATASET_NAME = "virtue"
return "virtue/virtue"
def process_doc(self, doc):
# Append identifiers before shuffling to calculate exact matches lateron & skip the first element of headers
return [x + [i] for i, x in enumerate(doc[1:])]
def load_doc(self, filename): def _process_doc(self, doc):
with open(filename, newline='') as file: return doc
filereader = csv.reader(file)
return self.process_doc(list(filereader))
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(*doc[1].split(" [SEP] ")) return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(
doc["scenario"],
doc["trait"]
)
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format(yesno(int(doc[0]))) return " {}".format(yesno(int(doc["label"])))
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes") ll_yes, _ = rf.loglikelihood(ctx, " yes")
...@@ -391,14 +356,15 @@ class EthicsVirtue(Ethics): ...@@ -391,14 +356,15 @@ class EthicsVirtue(Ethics):
def process_results(self, doc, results): def process_results(self, doc, results):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_yes > ll_no pred = ll_yes > ll_no
gold = bool(int(doc[0])) gold = bool(int(doc["label"]))
return { return {
"acc": pred == gold, "acc": pred == gold,
"em": [doc[-1], pred == gold] "em": [doc["group_id"], pred == gold]
} }
def calc_em(self, items): def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 5 are correct # Calculate exact matches - i.e. all in a pair of 5 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0]) preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)] em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
em_cors = [em_sums[i] == 5 for i in range(len(em_sums))] em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
......
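For reference, a minimal standalone sketch of the exact-match aggregation used above, assuming `items` is a list of (group_id, is_correct) pairs with exactly five entries per group; the final mean is an assumption, since the return statement sits outside the visible hunk:

def calc_em(items):
    # Sort by group_id so the five scenarios of each group are adjacent.
    preds_sort = sorted(items, key=lambda x: x[0])
    # Count correct predictions per consecutive block of five.
    em_sums = [
        sum(int(correct) for _, correct in preds_sort[5 * i:5 * i + 5])
        for i in range(len(preds_sort) // 5)
    ]
    # A group scores an exact match only if all five predictions are correct.
    em_cors = [s == 5 for s in em_sums]
    return sum(em_cors) / len(em_cors)

# Example: two groups of five; only the first group is fully correct -> EM = 0.5.
items = [(0, True)] * 5 + [(1, True)] * 4 + [(1, False)]
assert calc_em(items) == 0.5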
...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations. ...@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
Homepage: https://github.com/hendrycks/math Homepage: https://github.com/hendrycks/math
""" """
import abc import inspect
import json import lm_eval.datasets.hendrycks_math.hendrycks_math
from lm_eval.utils import sh
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -28,21 +25,8 @@ _CITATION = """ ...@@ -28,21 +25,8 @@ _CITATION = """
class Math(Task): class Math(Task):
DATASET_PATH = Path('data/MATH') DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
DATASET_NAME = None
def download(self):
if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
sh(f"mkdir -p {self.DATASET_PATH}")
download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
sh(f"""
tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
rm {self.DATASET_PATH}.tar
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -53,28 +37,25 @@ class Math(Task): ...@@ -53,28 +37,25 @@ class Math(Task):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _load_docs(self, path):
for file in sorted(path.iterdir()):
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self): def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info()) return map(self._load_doc, self.dataset["train"])
def validation_docs(self): def validation_docs(self):
return NotImplemented return NotImplemented
def test_docs(self): def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info()) return map(self._load_doc, self.dataset["test"])
def _load_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:" return "Problem: " + doc["problem"] + "\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc["answer"] return " " + doc["solution"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"]) return rf.greedy_until(ctx, ["\n"])
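Illustrative only: `remove_boxed` and `last_boxed_only_string` are defined further down in the file, outside this diff. Under the assumption that they extract the final \boxed{...} expression from a solution and strip the wrapper, `_load_doc` behaves roughly like this:

# Hypothetical solution string from a MATH problem.
solution = "We add the terms to get $2 + 2 = \\boxed{4}$."
# last_boxed_only_string(solution)  -> "\\boxed{4}"   (assumed behaviour)
# remove_boxed("\\boxed{4}")        -> "4"            (assumed behaviour)
# so _load_doc sets doc["answer"] = "4", while doc_to_target now returns the
# full solution text (" " + doc["solution"]) rather than just the boxed answer.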
...@@ -301,41 +282,34 @@ class Math(Task): ...@@ -301,41 +282,34 @@ class Math(Task):
class MathAlgebra(Math): class MathAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'algebra'
return 'algebra'
class MathCountingAndProbability(Math): class MathCountingAndProbability(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'counting_and_probability'
return 'counting_and_probability'
class MathGeometry(Math): class MathGeometry(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'geometry'
return 'geometry'
class MathIntermediateAlgebra(Math): class MathIntermediateAlgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'intermediate_algebra'
return 'intermediate_algebra'
class MathNumberTheory(Math): class MathNumberTheory(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'number_theory'
return 'number_theory'
class MathPrealgebra(Math): class MathPrealgebra(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'prealgebra'
return 'prealgebra'
class MathPrecalculus(Math): class MathPrecalculus(Math):
VERSION = 1 VERSION = 1
def get_file_info(self): DATASET_NAME = 'precalculus'
return 'precalculus'
...@@ -12,12 +12,7 @@ important shortcomings. ...@@ -12,12 +12,7 @@ important shortcomings.
Homepage: https://github.com/hendrycks/test Homepage: https://github.com/hendrycks/test
""" """
import csv
import random
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -61,25 +56,15 @@ def create_task(subject): ...@@ -61,25 +56,15 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask): class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = Path("data/hendrycksTest/") DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
def __init__(self, subject): def __init__(self, subject):
self.subject = subject self.DATASET_NAME = subject
super().__init__() super().__init__()
def download(self):
if not (self.DATASET_PATH / 'done').exists():
sh("mkdir -p data")
download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
sh("""
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
touch data/hendrycksTest/done
""")
def has_training_docs(self): def has_training_docs(self):
return True return False
def has_validation_docs(self): def has_validation_docs(self):
return True return True
...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def _convert_standard(self, doc): def validation_docs(self):
def format_example(doc, choices): return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
def format_example(doc, keys):
""" """
Question: <prompt> Question: <prompt>
Choices: Choices:
...@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask):
D. <choice4> D. <choice4>
Answer: Answer:
""" """
prompt = "Question: " + doc[0] + "\nChoices:\n" prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)]) prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "Answer:" prompt += "Answer:"
return prompt return prompt
choices = ['A', 'B', 'C', 'D'] keys = ['A', 'B', 'C', 'D']
return { return {
"query": format_example(doc, choices), "query": format_example(doc, keys),
"choices": doc[1:5], "choices": doc["choices"],
"gold": choices.index(doc[5]) "gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
} }
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is # fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't # in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename)) self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k) return rnd.sample(list(self._fewshot_docs), k)
......
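As an illustration of the refactored `_process_doc` above, a hypothetical record in the Hugging Face `hendrycks_test` schema and the prompt it yields:

# Hypothetical input record; field names follow the code above.
doc = {
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Madrid", "Paris", "Rome"],
    "answer": 2,
}
# _process_doc(doc) returns roughly:
# {
#     "query": "Question: What is the capital of France?\nChoices:\n"
#              "A. Berlin\nB. Madrid\nC. Paris\nD. Rome\nAnswer:",
#     "choices": ["Berlin", "Madrid", "Paris", "Rome"],
#     "gold": 2,
# }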
"""
KLUE: Korean Language Understanding Evaluation
https://arxiv.org/abs/2105.09680
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
_CITATION = """
@misc{park2021klue,
title={KLUE: Korean Language Understanding Evaluation},
author={Sungjoon Park and Jihyung Moon and Sungdong Kim and Won Ik Cho and Jiyoon Han and Jangwon Park and Chisung Song and Junseong Kim and Yongsook Song and Taehwan Oh and Joohong Lee and Juhyun Oh and Sungwon Lyu and Younghoon Jeong and Inkwon Lee and Sangwoo Seo and Dongjun Lee and Hyunwoo Kim and Myeonghwa Lee and Seongbo Jang and Seungwon Do and Sunkyoung Kim and Kyungtae Lim and Jongwon Lee and Kyumin Park and Jamin Shin and Seonghyun Kim and Lucy Park and Alice Oh and Jungwoo Ha and Kyunghyun Cho},
year={2021},
eprint={2105.09680},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
class STS(Task):
VERSION = 0
DATASET_PATH = "klue"
DATASET_NAME = "sts"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?\n문장 1:{}\n문장 2:{}\n정답:".format(
general_detokenize(doc["sentence1"]),
general_detokenize(doc["sentence2"])
)
def doc_to_target(self, doc):
return " {}".format({1: " 예", 0: " 아니"}[doc["labels"]["binary-label"]])
def construct_requests(self, doc, ctx):
ll_positive, _ = rf.loglikelihood(ctx, " 예")
ll_negative, _ = rf.loglikelihood(ctx, " 아니")
return ll_positive, ll_negative
def process_results(self, doc, results):
ll_positive, ll_negative = results
pred = ll_positive > ll_negative
gold = doc["labels"]["binary-label"]
return {
"acc": pred == gold,
"f1": (gold, pred)
}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
\ No newline at end of file
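By way of illustration, a hypothetical KLUE STS record (field names follow the `klue`/`sts` schema used above) and the zero-shot prompt the new task builds from it:

# Hypothetical example; the sentences are invented.
doc = {
    "sentence1": "오늘 날씨가 정말 좋다.",      # "The weather is really nice today."
    "sentence2": "오늘은 날씨가 매우 좋네요.",  # "The weather is very nice today."
    "labels": {"binary-label": 1},
}
# doc_to_text(doc) produces:
# 질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?
# 문장 1:오늘 날씨가 정말 좋다.
# 문장 2:오늘은 날씨가 매우 좋네요.
# 정답:
# doc_to_target(doc) is " 예"; construct_requests scores " 예" against " 아니".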
...@@ -12,12 +12,10 @@ in the broader discourse. ...@@ -12,12 +12,10 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
import json import inspect
import lm_eval.datasets.lambada.lambada
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import os
_CITATION = """ _CITATION = """
...@@ -34,19 +32,7 @@ _CITATION = """ ...@@ -34,19 +32,7 @@ _CITATION = """
class LAMBADA(Task): class LAMBADA(Task):
VERSION = 0 VERSION = 0
def download(self): DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada.lambada)
sh("mkdir -p data/lambada")
try:
if not os.path.exists("data/lambada/lambada_test.jsonl"):
download_file(
"http://eaidata.bmk.sh/data/lambada_test.jsonl",
local_file="data/lambada/lambada_test.jsonl",
expected_checksum="4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
except:
# fallback - for some reason best_download doesn't work all the time here
sh("wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl")
sh('echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226 data/lambada/lambada_test.jsonl" | sha256sum --check')
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -61,9 +47,7 @@ class LAMBADA(Task): ...@@ -61,9 +47,7 @@ class LAMBADA(Task):
pass pass
def validation_docs(self): def validation_docs(self):
with open("data/lambada/lambada_test.jsonl") as fh: return self.dataset["validation"]
for line in fh:
yield json.loads(line)
def test_docs(self): def test_docs(self):
pass pass
......
...@@ -13,12 +13,7 @@ in the broader discourse. ...@@ -13,12 +13,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
""" """
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from lm_eval.tasks.lambada import LAMBADA from lm_eval.tasks.lambada import LAMBADA
from best_download import download_file
_CITATION = """ _CITATION = """
...@@ -35,6 +30,7 @@ _CITATION = """ ...@@ -35,6 +30,7 @@ _CITATION = """
class LAMBADA_cloze(LAMBADA): class LAMBADA_cloze(LAMBADA):
VERSION = 0 VERSION = 0
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] + " ____. ->" return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
......
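For illustration, how the cloze variant rewrites a passage; the text is invented and the target is an assumption, since `doc_to_target` is outside this hunk:

# Hypothetical LAMBADA record; only the "text" field matters here.
doc = {"text": "He opened the door and saw his old friend"}
# LAMBADA_cloze.doc_to_text(doc) drops the final word and appends the cloze marker:
# "He opened the door and saw his old ____. ->"
# The held-out target would presumably be " friend".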