Unverified Commit 11f614b0 authored by Stella Biderman, committed by GitHub

Merge branch 'master' into task_doc

parents 0a6a9b7e e00d682f
@@ -10,10 +10,9 @@ a teacher who answers the questions by providing short excerpts (spans) from the
 Homepage: https://quac.ai/
 """
-import json
-import os
+import inspect
+import lm_eval.datasets.quac.quac
 from lm_eval.base import Task
-from ..utils import sh

 _CITATION = """
@@ -28,18 +27,8 @@ _CITATION = """
 class QuAC(Task):
     VERSION = 0
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if not os.path.exists('data/quac'):
-            # TODO: convert to use best_download
-            sh("""
-                mkdir -p data/quac
-                wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -O data/quac/train_v0.2.json
-                wget https://s3.amazonaws.com/my89public/quac/val_v0.2.json -O data/quac/val_v0.2.json
-                """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True
@@ -51,31 +40,29 @@ class QuAC(Task):
         return False

     def training_docs(self):
-        myjson = json.load(open('data/quac/train_v0.2.json'))['data']
-        return self.load_doc(myjson)
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs

     def validation_docs(self):
-        myjson = json.load(open('data/quac/val_v0.2.json'))['data']
-        return self.load_doc(myjson)
+        return map(self._process_doc, self.dataset["validation"])

     def test_docs(self):
         raise NotImplementedError("QuAC has no test docs.")

-    def load_doc(self, myjson):
-        docs = []
-        for item in myjson:
-            title = item['title'] + ' - ' + item['section_title']
-            paragraph = item['paragraphs'][0]['context'].replace("CANNOTANSWER", "")
-            qas = item['paragraphs'][0]['qas']
-            qa_pairs = [(qa['question'], qa['answers'][0]['text']) for qa in qas]
-            for (question, answer) in qa_pairs:
-                doc = { 'title': title, 'paragraph': paragraph, 'question': question, 'answer': answer }
-                docs.append(doc)
-        return docs
+    def _process_doc(self, doc):
+        doc["title"] = doc['title'] + ' - ' + doc['section_title']
+        return doc

     def doc_to_text(self, doc):
         return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['paragraph']
+
     def doc_to_target(self, doc):
         return doc['answer']
@@ -88,7 +75,7 @@ class QuAC(Task):
     :param ctx: str
         The context string, generated by fewshot_context. This includes the natural
         language description, as well as the few shot examples, and the question
        part of the document for `doc`.
     """
     # TODO: implement evaluation.
     raise NotImplementedError('Evaluation not implemented')
......
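For orientation, a rough usage sketch of the refactored task: documents now come from `self.dataset`, which the base `Task` builds from the bundled dataset script rather than a manual download. The registry helper and task name below are assumptions, not shown in this diff.

# Hedged sketch only; not part of this commit.
from lm_eval import tasks  # assumed to expose get_task / a task registry

task = tasks.get_task("quac")()            # task name assumed
doc = next(iter(task.validation_docs()))   # dict with title/paragraph/question/answer fields
print(task.doc_to_text(doc))               # "TITLE: ...\nPARAGRAPH: ...\n\nQ: ...\n\nA: "
print(task.doc_to_target(doc))             # the gold answer span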
@@ -12,9 +12,8 @@ Homepage: https://www.cs.cmu.edu/~glai1/data/race/
 import collections
 import datasets
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -35,16 +34,14 @@ class each:
         return list(map(self.f, other))

-class RACE(HFTask):
-    VERSION = 0
+class RACE(Task):
+    VERSION = 1
     DATASET_PATH = "race"
     DATASET_NAME = "high"

     cache = {}
     letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

-    assert datasets.__version__ == "1.15.1", "RACE requires datasets==1.15.1!"
-
     def has_training_docs(self):
         return True
@@ -107,6 +104,12 @@ class RACE(HFTask):
         text += self.last_problem(doc)['question']
         return text

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['article']
+
     def doc_to_target(self, doc):
         return " " + self.get_answer_option(self.last_problem(doc))
......
@@ -7,7 +7,8 @@ multiple-choice analogy questions; 5 choices per question.
 Homepage: https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art)
 """
-import os
+import inspect
+import lm_eval.datasets.sat_analogies.sat_analogies
 from lm_eval.base import MultipleChoiceTask
@@ -25,20 +26,18 @@ _CITATION = """
 """

 class SATAnalogies(MultipleChoiceTask):
     VERSION = 0
-    NEEDS_MANUAL_DL = True
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        # We should be using a checksum here.
-        # The canonical sha256 hash is below:
-        # 9dece377d8d57253ef8c78370ff15de0bb1d9e90a82c815a67ba1e621e921bfc
-
-        if not os.path.exists('data/sat/SAT-package-V3.txt'):
-            raise NotImplementedError('SAT Analogies dataset is not provided. Follow instructions on https://aclweb.org/aclwiki/SAT_Analogy_Questions_(State_of_the_art) to locate.')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.sat_analogies.sat_analogies)
+    DATASET_NAME = None
+
+    def __init__(self, data_dir: str):
+        """
+        SAT Analog Questions is not publicly available. You must request the data
+        by emailing Peter Turney and then download it to a local directory path
+        which should be passed into the `data_dir` arg.
+        """
+        super().__init__(data_dir=data_dir)

     def has_training_docs(self):
         return False
@@ -51,38 +50,26 @@ class SATAnalogies(MultipleChoiceTask):
     def training_docs(self):
         return []

+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
     def test_docs(self):
         return []

-    def validation_docs(self):
-        data = []
-
-        with open("data/sat/SAT-package-V3.txt", "r") as f:
-            record = []
-            for line in f:
-                line = line.strip()
-                if len(line) == 0 and record:
-                    data.append(record)
-                    record = []
-                elif len(line) > 0 and line[0] == '#':
-                    continue
-                else:
-                    record.append(line)
-            data.append(record)
-
-        for record in data:
-            source = record[-8]
-            query = record[-7]
-            choices = record[-6:-1]
-            answer_key = record[-1]
-            doc = {
-                'source': source,
-                'query': query.split(' ')[:2],
-                'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in choices],
-                'gold': ['a','b','c','d','e'].index(answer_key.strip()),
-            }
-            yield doc
+    def _process_doc(self, doc):
+        return {
+            'source': doc['source'],
+            'query': doc['stem'].split(' ')[:2],
+            'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]],
+            'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()),
+        }

     def doc_to_text(self, doc):
         return "{} is to {} as".format(*doc['query'])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["source"] + "\n" + " ".join(doc["query"])
@@ -9,11 +9,7 @@ with supporting evidence for the correct answer is provided.
 Homepage: https://allenai.org/data/sciq
 """
-import os
-import json
-import zipfile
 from lm_eval.base import MultipleChoiceTask
-from best_download import download_file

 _CITATION = """
@@ -28,17 +24,8 @@ _CITATION = """
 class SciQ(MultipleChoiceTask):
     VERSION = 0
-
-    # Multiple languages and multiple years
-    def download(self):
-        if not os.path.exists('data/sciq'):
-            os.makedirs('data/sciq', exist_ok=True)
-            download_file(
-                'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
-                local_file='data/sciq/SciQ.zip',
-                expected_checksum='7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
-            )
-            with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
-                zf.extractall("data/sciq/")
+    DATASET_PATH = "sciq"
+    DATASET_NAME = None

     def has_training_docs(self):
         return True
@@ -49,36 +36,38 @@ class SciQ(MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         choices = [
             doc["distractor1"],
             doc["distractor2"],
             doc["distractor3"],
             doc["correct_answer"],
         ]
         src = doc['support']
         out_doc = {
-            "source" : src,
-            "query" : doc['question'],
-            "choices" : choices,
-            "gold" : 3,
+            "source": src,
+            "query": doc['question'],
+            "choices": choices,
+            "gold": 3,
         }
         return out_doc

-    def load_docs(self, textfilename):
-        with open(textfilename, 'r') as j:
-            docs = json.loads(j.read())
-            for record in docs:
-                yield self._convert_standard(record)
-
-    def training_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")
-
-    def validation_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")
-
-    def test_docs(self):
-        return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
-
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["source"] + " " + doc["query"]
@@ -15,9 +15,7 @@ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
 """
 import datasets
 from math import exp
-from lm_eval.base import rf
-from lm_eval.metrics import f1_score, mean
-from . common import HFTask
+from lm_eval.base import rf, Task
 from functools import partial
 from packaging import version
@@ -45,7 +43,7 @@ def _squad_agg(key, items):
     return _squad_metric(predictions=predictions, references=references)[key]

-class SQuAD2(HFTask):
+class SQuAD2(Task):
     VERSION = 1
     DATASET_PATH = "squad_v2"
     DATASET_NAME = None
@@ -63,14 +61,20 @@ class SQuAD2(HFTask):
         return False

     def training_docs(self):
-        return self.data["train"]
+        return self.dataset["train"]

     def validation_docs(self):
-        return self.data["validation"]
+        return self.dataset["validation"]

     def doc_to_text(self, doc):
         return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['context']
+
     def doc_to_target(self, doc):
         answer_list = doc['answers']['text']
         if len(answer_list) > 0:
......
@@ -8,8 +8,9 @@ to choose the correct ending to a four-sentence story.
 Homepage: https://cs.rochester.edu/nlp/rocstories/
 """
-import csv
-from lm_eval.base import Task
+import numpy as np
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -34,11 +35,16 @@ _CITATION = """
 class StoryCloze(Task):
     VERSION = 0
-    NEEDS_MANUAL_DL = True
+    DATASET_PATH = "story_cloze"
+    DATASET_NAME = None

-    def download(self):
-        #TODO: replace with Eye link
-        pass
+    def __init__(self, data_dir: str):
+        """
+        StoryCloze is not publicly available. You must download the data by
+        following https://cs.rochester.edu/nlp/rocstories/ and pass the folder
+        path into the `data_dir` arg.
+        """
+        super().__init__(data_dir=data_dir)

     def has_training_docs(self):
         return False
@@ -52,40 +58,57 @@ class StoryCloze(Task):
     def training_docs(self):
         pass

-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return list(filereader)
-
     def validation_docs(self):
-        return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
+        return self.dataset["validation"]

     def test_docs(self):
-        return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
+        return self.dataset["test"]

     def doc_to_text(self, doc):
-        return ' '.join([*doc[1:5]])
+        return ' '.join([
+            doc["input_sentence_1"],
+            doc["input_sentence_2"],
+            doc["input_sentence_3"],
+            doc["input_sentence_4"],
+        ])
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return ' '.join([
+            doc["input_sentence_1"],
+            doc["input_sentence_2"],
+            doc["input_sentence_3"],
+            doc["input_sentence_4"],
+        ])

     def doc_to_target(self, doc):
-        return " " + doc[int(doc[-1]) - 4]
+        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
+        # `- 1` because the `answer_right_ending` index is 1-based.
+        return " " + clozes[doc["answer_right_ending"] - 1]

     def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
+        lls = [
+            rf.loglikelihood(ctx, " {}".format(choice))[0]
+            for choice in clozes
+        ]
+        return lls

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

         :param doc:
@@ -93,23 +116,36 @@ class StoryCloze(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = doc["answer_right_ending"] - 1
+        acc = 1. if np.argmax(results) == gold else 0.
+        return {
+            "acc": acc
+        }

     def aggregation(self):
         """
         :returns: {str: [float] -> float}
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }

     def higher_is_better(self):
         """
         :returns: {str: bool}
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }

+
+class StoryCloze2016(StoryCloze):
+    DATASET_NAME = "2016"
+
+
+class StoryCloze2018(StoryCloze):
+    DATASET_NAME = "2018"
@@ -12,10 +12,9 @@ TODO: WSC requires free-form generation.
 import numpy as np
 import sklearn
 import transformers.data.metrics.squad_metrics as squad_metrics
-from . common import HFTask, yesno
-from lm_eval.base import rf
-from ..metrics import mean, acc_all, metric_max_over_ground_truths
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
+from lm_eval.utils import general_detokenize

 _CITATION = """
@@ -33,7 +32,7 @@ _CITATION = """
 """

-class BoolQ(HFTask):
+class BoolQ(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "boolq"
@@ -47,8 +46,22 @@ class BoolQ(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['passage']
+
     def doc_to_target(self, doc):
         return " " + yesno(doc['label'])
@@ -81,7 +94,7 @@ class BoolQ(HFTask):
     }

-class CommitmentBank(HFTask):
+class CommitmentBank(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "cb"
@@ -95,6 +108,14 @@ class CommitmentBank(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
             doc["premise"],
@@ -148,7 +169,7 @@ class CommitmentBank(HFTask):
     }

-class Copa(HFTask):
+class Copa(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "copa"
@@ -162,6 +183,14 @@ class Copa(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         # Drop the period
         connector = {
@@ -208,7 +237,7 @@ class Copa(HFTask):
         return choice[0].lower() + choice[1:]

-class MultiRC(HFTask):
+class MultiRC(Task):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "multirc"
@@ -222,6 +251,14 @@ class MultiRC(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
@@ -260,7 +297,7 @@ class MultiRC(HFTask):
     }

-class ReCoRD(HFTask):
+class ReCoRD(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "record"
@@ -279,13 +316,13 @@ class ReCoRD(HFTask):
         # Each doc consists of multiple answer candidates, each of which is scored yes/no.
         if self._training_docs is None:
             self._training_docs = []
-            for doc in self.data["train"]:
+            for doc in self.dataset["train"]:
                 self._training_docs.append(self._process_doc(doc))
         return self._training_docs

     def validation_docs(self):
         # See: training_docs
-        for doc in self.data["validation"]:
+        for doc in self.dataset["validation"]:
             yield self._process_doc(doc)

     @classmethod
@@ -349,7 +386,7 @@ class ReCoRD(HFTask):
     }

-class WordsInContext(HFTask):
+class WordsInContext(Task):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "wic"
@@ -363,6 +400,14 @@ class WordsInContext(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
                " two sentences above?\nAnswer:".format(
@@ -401,7 +446,7 @@ class WordsInContext(HFTask):
     }

-class SGWinogradSchemaChallenge(HFTask):
+class SGWinogradSchemaChallenge(Task):
     VERSION = 0
     # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
     # binary version of the task.
@@ -423,11 +468,14 @@ class SGWinogradSchemaChallenge(HFTask):
             # GPT-3 Paper's format only uses positive examples for fewshot "training"
             self._training_docs = [
                 doc for doc in
-                self.data["train"]
+                self.dataset["train"]
                 if doc["label"]
             ]
         return self._training_docs

+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         raw_passage = doc["text"]
         # NOTE: HuggingFace span indices are word-based not character-based.
......
"""
SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference
https://arxiv.org/pdf/1808.05326.pdf
SWAG (Situations With Adversarial Generations) is an adversarial dataset
that consists of 113k multiple choice questions about grounded situations. Each
question is a video caption from LSMDC or ActivityNet Captions, with four answer
choices about what might happen next in the scene. The correct answer is the
(real) video caption for the next event in the video; the three incorrect
answers are adversarially generated and human verified, so as to fool machines
but not humans.
Homepage: https://rowanzellers.com/swag/
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@inproceedings{zellers2018swagaf,
title={SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference},
author={Zellers, Rowan and Bisk, Yonatan and Schwartz, Roy and Choi, Yejin},
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
year={2018}
}
"""
class SWAG(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "swag"
DATASET_NAME = "regular"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
out_doc = {
"query": doc["startphrase"],
"choices": [doc["ending0"], doc["ending1"], doc["ending2"], doc["ending3"]],
"gold": int(doc["label"]),
}
return out_doc
def doc_to_text(self, doc):
return doc["query"]
@@ -90,7 +90,7 @@ class GeneralTranslationTask(Task):
         super().__init__()

-    def download(self):
+    def download(self, data_dir=None, cache_dir=None, download_mode=None):
         # This caches in the users home dir automatically
         self.src_file, self.ref_file = \
             sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
@@ -128,6 +128,12 @@ class GeneralTranslationTask(Task):
         tar_lang = code_to_language(language_codes[1])
         return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["src"]
+
     def doc_to_target(self, doc):
         # This shows a single target, though there may be multiple targets in a lang test
         return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
......
@@ -9,13 +9,10 @@ high quality distant supervision for answering the questions.
 Homepage: https://nlp.cs.washington.edu/triviaqa/
 """
-import os
-import json
-import jsonlines
+import inspect
+import lm_eval.datasets.triviaqa.triviaqa
 from lm_eval.base import Task, rf
-from ..metrics import mean
-from ..utils import sh
-from best_download import download_file
+from lm_eval.metrics import mean

 _CITATION = """
@@ -33,14 +30,8 @@ _CITATION = """
 class TriviaQA(Task):
     VERSION = 0
-    def download(self):
-        if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
-            os.makedirs("data/triviaqa/", exist_ok=True)
-            download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", local_file="data/triviaqa/triviaqa-unfiltered.tar.gz", expected_checksum="adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
-            sh("""
-            cd data/triviaqa/
-            tar -xf triviaqa-unfiltered.tar.gz
-            """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True
@@ -52,19 +43,25 @@ class TriviaQA(Task):
         return False

     def training_docs(self):
-        return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
+        return self.dataset['train']

     def validation_docs(self):
-        return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
+        return self.dataset['validation']

     def test_docs(self):
         raise NotImplementedError()

     def doc_to_text(self, doc):
-        return f"Question: {doc['Question']}\nAnswer:"
+        return f"Question: {doc['question']}\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['question']
+
     def doc_to_target(self, doc):
-        return " " + doc['Answer']['Value']
+        return " " + doc['answer']['value']

     def _remove_prefixes(self, aliases):
         # Optimization: Remove any alias that has a strict prefix elsewhere in the list
@@ -74,12 +71,11 @@ class TriviaQA(Task):
         for alias in aliases[1:]:
             if not alias.startswith(ret[-1]):
                 ret.append(alias)
         return ret

     def construct_requests(self, doc, ctx):
         ret = []
-        for alias in self._remove_prefixes(doc['Answer']['Aliases']):
+        for alias in self._remove_prefixes(doc['answer']['aliases']):
             _, is_prediction = rf.loglikelihood(ctx, " " + alias)
             ret.append(is_prediction)
         return ret
......
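The _remove_prefixes optimization is easier to see on a toy alias list. The lines elided from the hunk above presumably sort the aliases and seed the result with the first one; this standalone sketch assumes that behaviour and is not the repository's exact code.

# Standalone sketch of the same idea (assumed details noted above).
def remove_prefixes(aliases):
    # Drop any alias that merely extends an earlier, shorter alias: the shorter
    # prefix is already scored, so the longer variant adds no information.
    aliases = sorted(aliases)
    ret = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(ret[-1]):
            ret.append(alias)
    return ret

print(remove_prefixes(["United States of America", "United States", "USA"]))
# -> ['USA', 'United States']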
@@ -19,16 +19,14 @@ we could try this?
 Homepage: https://github.com/sylinrl/TruthfulQA
 """
-import csv
-import json
+import inspect
 import numpy as np
 import sacrebleu
+import datasets
+import lm_eval.datasets.truthfulqa.truthfulqa
 from rouge_score import rouge_scorer, scoring
 from lm_eval.base import rf, Task
-from pathlib import Path
-from best_download import download_file
-from ..metrics import mean
-from datasets import load_metric
+from lm_eval.metrics import mean

 _CITATION = """
@@ -62,15 +60,8 @@ QA_PROMPT = (
 class TruthfulQAMultipleChoice(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/mc')
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json"
-        checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954"
-        download_file(mc_url, local_file=str(self.DATASET_PATH / "mc_task.json"), expected_checksum=checksum)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "multiple_choice"

     def has_training_docs(self):
         return False
@@ -85,8 +76,7 @@ class TruthfulQAMultipleChoice(Task):
         raise NotImplementedError()

     def validation_docs(self):
-        with open(self.DATASET_PATH / "mc_task.json") as f:
-            return json.load(f)
+        return self.dataset["validation"]

     def test_docs(self):
         raise NotImplementedError()
@@ -94,6 +84,12 @@ class TruthfulQAMultipleChoice(Task):
     def doc_to_text(self, doc):
         return QA_PROMPT + "\n\nQ: " + doc['question'] + "\nA:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['question']
+
     def doc_to_target(self, doc):
         return " "
@@ -121,7 +117,7 @@ class TruthfulQAMultipleChoice(Task):
             return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]
         # MC1 and MC2 targets are not always the same set of strings so we collect
         # likelihoods separately for simpler processing.
-        return get_lls(doc['mc1_targets']) + get_lls(doc['mc2_targets'])
+        return get_lls(doc['mc1_targets']["choices"]) + get_lls(doc['mc2_targets']["choices"])

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -139,14 +135,14 @@ class TruthfulQAMultipleChoice(Task):
         def mc2(lls):
             # Split on the first `0` as everything before it is true (`1`).
-            split_idx = list(doc['mc2_targets'].values()).index(0)
+            split_idx = list(doc['mc2_targets']["labels"]).index(0)
             # Compute the normalized probability mass for the correct answer.
             ll_true, ll_false = lls[:split_idx], lls[split_idx:]
             p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
             p_true = p_true / (sum(p_true) + sum(p_false))
             return sum(p_true)

-        split_idx = len(doc['mc1_targets'])
+        split_idx = len(doc['mc1_targets']["choices"])
         mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]
         return {
             "mc1": mc1(mc1_lls),
@@ -168,19 +164,12 @@ class TruthfulQAMultipleChoice(Task):
 class TruthfulQAGeneration(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/generation')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "generation"

     def __init__(self):
         super().__init__()
-        self.bleurt = load_metric("bleurt", cache_dir="lm_cache")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv"
-        checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2"
-        download_file(url, local_file=str(self.DATASET_PATH / "TruthfulQA.csv"), expected_checksum=checksum)
+        self.bleurt = datasets.load_metric("bleurt")

     def has_training_docs(self):
         return False
@@ -194,36 +183,29 @@ class TruthfulQAGeneration(Task):
     def training_docs(self):
         raise NotImplementedError()

-    def _split_multi_answer(self, answers, sep=';'):
-        answers = answers.strip().split(sep)
-        split_answers = []
+    def _format_answers(self, answers):
+        formatted_answers = []
         for answer in answers:
             answer = answer.strip()
             if len(answer):
                 # Add a period after all answers.
                 if answer[-1] != '.':
-                    split_answers.append(answer + '.')
+                    formatted_answers.append(answer + '.')
                 else:
-                    split_answers.append(answer)
-        return split_answers
+                    formatted_answers.append(answer)
+        return formatted_answers

     def validation_docs(self):
-        with open(self.DATASET_PATH / "TruthfulQA.csv", newline='') as csvfile:
-            doc_reader = csv.DictReader(csvfile)
-            for doc in doc_reader:
-                # Ensure that references exist.
-                if not doc['Correct Answers'] or not doc['Incorrect Answers']:
-                    continue
-                correct_answers = self._split_multi_answer(doc['Correct Answers'])
-                if "I have no comment." not in correct_answers:
-                    correct_answers.append("I have no comment.")
-                incorrect_answers = self._split_multi_answer(doc['Incorrect Answers'])
-                doc = {
-                    'question': doc['Question'].strip(),
-                    'correct_answers': correct_answers,
-                    'incorrect_answers': incorrect_answers
-                }
-                yield doc
+        for doc in self.dataset["validation"]:
+            incorrect_answers = self._format_answers(doc['incorrect_answers'])
+            correct_answers = self._format_answers(doc['correct_answers'])
+            if "I have no comment." not in correct_answers:
+                correct_answers.append("I have no comment.")
+            yield {
+                'question': doc['question'].strip(),
+                'correct_answers': correct_answers,
+                'incorrect_answers': incorrect_answers
+            }

     def test_docs(self):
         raise NotImplementedError()
......
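A toy walk-through of the mc2 computation above, with made-up log-likelihood values: given labels [1, 1, 0, 0], everything before the first 0 counts as a true answer, and the metric is the normalized probability mass placed on those answers.

# Illustration only; numbers are invented.
import numpy as np

lls = np.array([-1.0, -2.0, -1.5, -3.0])   # log-likelihoods for the mc2 choices
labels = [1, 1, 0, 0]                      # 1 = true answer, 0 = false answer
split_idx = labels.index(0)
ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(ll_true), np.exp(ll_false)
mc2 = p_true.sum() / (p_true.sum() + p_false.sum())
print(round(float(mc2), 3))                # 0.648 for these numbers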
@@ -8,11 +8,8 @@ addition, or deletion of characters, and asking it to recover the original word.
 Homepage: https://github.com/openai/gpt-3/tree/master/data
 """
-import gzip
-import json
-import shutil
-from pathlib import Path
-from best_download import download_file
+import inspect
+import lm_eval.datasets.unscramble.unscramble
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
@@ -32,30 +29,10 @@ _CITATION = """
 """

-def extract_gzip(gz, to):
-    with gzip.open(gz, 'rb') as fin:
-        with open(to, 'wb') as fout:
-            shutil.copyfileobj(fin, fout)
-
 class WordUnscrambleTask(Task):
     VERSION = 0
-    BASE_PATH = Path("data/unscramble")
-    FILENAME = None
-    CHECKSUM = None  # SHA256 Checksum.
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if not self.BASE_PATH.exists():
-            Path.mkdir(self.BASE_PATH, parents=True)
-        file = self.BASE_PATH / self.FILENAME
-        if not file.exists():
-            rawfile = file.parent / (file.name + ".gz")
-            base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
-            download_file(f"{base_url}/{self.FILENAME}.gz", local_file=str(rawfile), expected_checksum=self.CHECKSUM)
-            extract_gzip(gz=rawfile, to=file)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.unscramble.unscramble)
+    DATASET_NAME = None

     def has_training_docs(self):
         return False
@@ -67,12 +44,17 @@ class WordUnscrambleTask(Task):
         return False

     def validation_docs(self):
-        file = self.BASE_PATH / self.FILENAME
-        return (json.loads(line) for line in open(file).read().splitlines())
+        return self.dataset["validation"]

     def doc_to_text(self, doc):
         return doc["context"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["context"]
+
     def doc_to_target(self, doc):
         return doc["completion"]
@@ -99,25 +81,20 @@ class WordUnscrambleTask(Task):
 class Anagrams1(WordUnscrambleTask):
-    FILENAME = "mid_word_1_anagrams.jsonl"
-    CHECKSUM = "6768a86896083199de4815d4964cb2f6f1046476cfd80c2a562784f182905979"
+    DATASET_NAME = "mid_word_1_anagrams"

 class Anagrams2(WordUnscrambleTask):
-    FILENAME = "mid_word_2_anagrams.jsonl"
-    CHECKSUM = "c3d839d09a7954b78a27cd2cd75d4ed0488656c56ef4dbd741a005343826cb01"
+    DATASET_NAME = "mid_word_2_anagrams"

 class CycleLetters(WordUnscrambleTask):
-    FILENAME = "cycle_letters_in_word.jsonl"
-    CHECKSUM = "1689c9002bb8c5988bf5f05e977c9db92f57932c1b5a38998c29ac0dd71e1d42"
+    DATASET_NAME = "cycle_letters_in_word"

 class RandomInsertion(WordUnscrambleTask):
-    FILENAME = "random_insertion_in_word.jsonl"
-    CHECKSUM = "72e65d83da53d15752ee0c47379509de149ddbad32d61184e5991df29616b78a"
+    DATASET_NAME = "random_insertion_in_word"

 class ReversedWords(WordUnscrambleTask):
-    FILENAME = "reversed_words.jsonl"
-    CHECKSUM = "133a08f875cd6c1ef8608a3233571a773881cc27b1c707de738cc6543439332a"
+    DATASET_NAME = "reversed_words"
@@ -9,9 +9,8 @@ The questions are popular ones asked on the web (at least in 2013).
 Homepage: https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a
 """
-from . common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -32,7 +31,7 @@ _CITATION = """
 """

-class WebQs(HFTask):
+class WebQs(Task):
     VERSION = 0
     DATASET_PATH = "web_questions"
     DATASET_NAME = None
@@ -46,9 +45,23 @@ class WebQs(HFTask):
     def has_test_docs(self):
         return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['question']
+
     def doc_to_target(self, doc):
         # this picks one answer to be the "correct" one, despite sometimes
         # multiple correct answers being possible.
......
@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
-import os
 import re
-from lm_eval.base import rf, PerplexityTask
-from lm_eval.utils import sh
-from best_download import download_file
+import inspect
+import lm_eval.datasets.wikitext.wikitext
+from lm_eval.base import PerplexityTask

 _CITATION = """
@@ -64,45 +63,36 @@ def wikitext_detokenizer(string):
 class WikiText(PerplexityTask):
     VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
+    DATASET_NAME = "wikitext-2-raw-v1"

-    def download(self):
-        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
-            os.makedirs("data/wikitext/", exist_ok=True)
-            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
-            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
-
-    def has_validation_docs(self):
+    def has_training_docs(self):
         return True

-    def has_train_docs(self):
+    def has_validation_docs(self):
         return True

     def has_test_docs(self):
         return True

-    def docs_for_split(self, split):
-        ret = []
-        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
-            rline = line.replace("= = =", "===").replace("= =", "==").strip()
-            if rline.startswith('= ') and rline.strip().endswith(' ='):
-                s = '\n'.join(ret)
-                if s.strip(): yield s
-                ret = []
-            ret.append(line)
-        yield '\n'.join(ret)
-
-    def validation_docs(self):
-        return self.docs_for_split('valid')
+    def training_docs(self):
+        return map(self._process_doc, self.dataset["train"])

-    def train_docs(self):
-        return self.docs_for_split('train')
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])

     def test_docs(self):
-        return self.docs_for_split('test')
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
+        return doc["page"]

     def doc_to_target(self, doc):
         return wikitext_detokenizer(doc)

+    def should_decontaminate(self):
+        return True
+
     def count_words(self, doc):
         # count number of words in *original doc before detokenization*
         return len(re.split(r"\s+", doc))
@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847
 Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
 """
 import numpy as np
-from . common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -30,7 +29,7 @@ _CITATION = """
 """

-class Winogrande(HFTask):
+class Winogrande(Task):
     VERSION = 0
     DATASET_PATH = "winogrande"
     DATASET_NAME = "winogrande_xl"
@@ -46,9 +45,23 @@ class Winogrande(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return self.partial_context(doc, doc["option" + doc["answer"]])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["sentence"]
+
     @classmethod
     def partial_context(cls, doc, option):
         # Substitute the pronoun in the sentence with the specified option
......
@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.0
 Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html
 """
 import numpy as np
-import random
-from lm_eval.base import rf
-from ..metrics import mean
-from . common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -37,7 +35,7 @@ _CITATION = """
 """

-class WinogradSchemaChallenge273(HFTask):
+class WinogradSchemaChallenge273(Task):
     VERSION = 0
     DATASET_PATH = "winograd_wsc"
     DATASET_NAME = "wsc273"
@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask):
     upper_pronouns = ["A", "An", "The", "She", "He",
                       "It", "They", "My", "His", "Her", "Their"]

-    def __init__(self):
-        super().__init__()
-        self.data = self.__clean_data()
+    def has_training_docs(self):
+        return False

-    def __clean_data(self):
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         # The HF implementation of `wsc273` is not `partial evaluation` friendly.
-        data = []
-        for doc in self.data["test"]:
-            doc["text"] = doc["text"].replace("  ", " ")
-            doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
-            doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
-            data.append(doc)
-        return {"test": data}
+        doc["text"] = doc["text"].replace("  ", " ")
+        doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
+        doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
+        return doc

     def __normalize_option(self, doc, option):
         # Append `'s` to possessive determiner based options.
@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask):
                 return option.replace(pronoun, pronoun.lower())
         return option

-    def has_training_docs(self):
-        return False
-
-    def has_validation_docs(self):
-        return False
-
-    def has_test_docs(self):
-        return True
-
     def fewshot_examples(self, k, rnd):
         # NOTE: `super().fewshot_examples` samples from training docs which are
         # not available for this test-set-only dataset.
@@ -91,6 +85,12 @@ class WinogradSchemaChallenge273(HFTask):
     def doc_to_text(self, doc):
         return self.partial_context(doc, doc["options"][doc["label"]])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["text"]
+
     @classmethod
     def partial_context(cls, doc, option):
         # Substitute the pronoun in the original text with the specified
......
 import argparse
 import json
 import logging
+import fnmatch

 from lm_eval import tasks, evaluator

 logging.getLogger("openai").setLevel(logging.WARNING)

+class MultiChoice:
+    def __init__(self, choices):
+        self.choices = choices
+
+    # Simple wildcard support (linux filename patterns)
+    def __contains__(self, values):
+        for value in values.split(","):
+            if len(fnmatch.filter(self.choices, value)) == 0:
+                return False
+        return True
+
+    def __iter__(self):
+        for choice in self.choices:
+            yield choice
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
+    parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
     parser.add_argument('--provide_description', action="store_true")
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
@@ -19,22 +35,35 @@ def parse_args():
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
+    parser.add_argument('--decontamination_ngrams_path', default=None)
     parser.add_argument('--description_dict_path', default=None)
     parser.add_argument('--check_integrity', action="store_true")

     return parser.parse_args()

+# Returns a list containing all values of the source_list that
+# match at least one of the patterns
+def pattern_match(patterns, source_list):
+    task_names = set()
+    for pattern in patterns:
+        for matching in fnmatch.filter(source_list, pattern):
+            task_names.add(matching)
+    return list(task_names)
+
 def main():
     args = parse_args()

     assert not args.provide_description  # not implemented

     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

-    if args.tasks == "all_tasks":
+    if args.tasks is None:
         task_names = tasks.ALL_TASKS
     else:
-        task_names = args.tasks.split(",")
+        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
+
+    print(f"Selected Tasks: {task_names}")

     description_dict = {}
     if args.description_dict_path:
@@ -51,11 +80,11 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
+        decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity
     )

     dumped = json.dumps(results, indent=2)
     print(dumped)

     if args.output_path:
......
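To illustrate the new task selection: --tasks now accepts comma-separated shell-style wildcards, which pattern_match expands against tasks.ALL_TASKS. A standalone sketch with invented task names:

# Same logic as the pattern_match helper added above; task names are invented.
import fnmatch

def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return list(task_names)

available = ["lambada", "boolq", "hendrycksTest-anatomy", "hendrycksTest-astronomy"]
print(sorted(pattern_match(["boolq", "hendrycksTest-*"], available)))
# -> ['boolq', 'hendrycksTest-anatomy', 'hendrycksTest-astronomy']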
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
\ No newline at end of file
# Compresses each sorted 13-gram bucket file with zstd and moves the result
# (along with info.json) into the output directory.
import glob
import argparse
import os
import subprocess
import shutil

from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)

def process_task(working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm):
    command = f"zstd {bucket_file_path}"
    logger.info(command)
    subprocess.call(command, shell=True)

    compressed_file = bucket_file_path + ".zst"
    if output_directory:
        shutil.move(compressed_file, output_directory)

    os.remove(bucket_file_path)
    global_tqdm.update()

def compress_and_move(working_directory, output_directory, process_count):
    os.makedirs(output_directory, exist_ok=True)
    original_info_file_path = os.path.join(working_directory, "info.json")
    assert os.path.exists(original_info_file_path)

    tasks = []
    bucket_file_paths = glob.glob(os.path.join(working_directory, "output", "*.bkt.txt.sorted"))
    for bucket_file_path in bucket_file_paths:
        task = (process_task, (working_directory, output_directory, bucket_file_path))
        tasks.append(task)

    pool = TqdmMultiProcessPool(process_count)
    on_done = lambda _: None
    on_error = lambda _: None
    global_progress = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="file")
    _ = pool.map(global_progress, tasks, on_error, on_done)

    shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))

parser = argparse.ArgumentParser(description='Compress and package sorted 13-gram buckets')
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)

if __name__ == '__main__':
    version = 1.00
    print(f"Running version {version}")

    logfile_path = "compress_and_package.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    compress_and_move(args.working_directory, args.output_directory, args.process_count)
\ No newline at end of file
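
# Illustrative invocation of the script above (assuming it is saved as
# compress_and_package.py and the working directory holds the sorted
# *.bkt.txt.sorted buckets under output/):
#   python compress_and_package.py -dir working -output packaged -procs 8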
@@ -21,8 +21,10 @@ Arguments
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
@@ -30,32 +32,89 @@ from signal import SIGINT
from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)
terminate = False
def handler(signal_received, frame):
    global terminate
    terminate = True

def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1
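
# Note: yield_pile walks the pile shards in sorted filename order and yields
# (global_offset, document) pairs. When resuming, whole shards that finish before
# the checkpoint offset are skipped up front; already-seen documents inside the
# current shard are then skipped by the caller via its offset check.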
# Hash buckets > disk backed files. Supports file position checkpointing and resuming.
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()
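
# Illustrative usage (hypothetical paths):
#   buckets = Buckets("working/output", 500)
#   buckets.add_data("some ngram", "some ngram 12345")  # routed by hash(key) % 500
#   buckets.save_checkpoint()                           # record current file offsets
#   buckets.close_buckets()
# Because add_data routes by Python's built-in hash(), bucket assignment is only
# reproducible across runs when string hash randomization is disabled, which is
# why the entry point below refuses to run without PYTHONHASHSEED=0.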
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)
@@ -68,49 +127,56 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")

    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()
                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()
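
# Resume flow in brief: on restart the checkpointed pile offset is loaded, documents
# before it are fast-forwarded rather than re-processed, the progress bar is reset to
# the full pile document count once the checkpoint is reached, and both the bucket
# file offsets and the pile offset are checkpointed again every batch_size documents,
# so a later failure loses at most one batch of work.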
@@ -120,6 +186,12 @@ parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == '__main__':
    version = 1.00
    print(f"Running version {version}")

    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)
@@ -128,4 +200,8 @@ if __name__ == '__main__':
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))
\ No newline at end of file