Commit 1f8a8c1d authored by jon-tow's avatar jon-tow
Browse files

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
""" """
QuAC: Question Answering in Context QuAC: Question Answering in Context
https://arxiv.org/abs/1808.07036 https://arxiv.org/abs/1808.07036
Question Answering in Context (QuAC) is a dataset for modeling, understanding, and Question Answering in Context (QuAC) is a dataset for modeling, understanding, and
participating in information seeking dialog. Data instances consist of an interactive participating in information seeking dialog. Data instances consist of an interactive
dialog between two crowd workers: (1) a student who poses a sequence of freeform dialog between two crowd workers: (1) a student who poses a sequence of freeform
questions to learn as much as possible about a hidden Wikipedia text, and (2) questions to learn as much as possible about a hidden Wikipedia text, and (2)
a teacher who answers the questions by providing short excerpts (spans) from the text. a teacher who answers the questions by providing short excerpts (spans) from the text.
Homepage: https://quac.ai/ Homepage: https://quac.ai/
""" """
import inspect import inspect
import lm_eval.datasets.quac.quac import lm_eval.datasets.quac.quac
from lm_eval.base import Task from lm_eval.base import Task
_CITATION = """ _CITATION = """
@article{choi2018quac, @article{choi2018quac,
title={Quac: Question answering in context}, title={Quac: Question answering in context},
author={Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke}, author={Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke},
journal={arXiv preprint arXiv:1808.07036}, journal={arXiv preprint arXiv:1808.07036},
year={2018} year={2018}
} }
""" """
class QuAC(Task): class QuAC(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac) DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac)
DATASET_NAME = None DATASET_NAME = None
def has_training_docs(self): def has_training_docs(self):
return True return True
def has_validation_docs(self): def has_validation_docs(self):
return True return True
def has_test_docs(self): def has_test_docs(self):
return False return False
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"])) self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._process_doc, self.dataset["validation"]) return map(self._process_doc, self.dataset["validation"])
def test_docs(self): def test_docs(self):
raise NotImplementedError("QuAC has no test docs.") raise NotImplementedError("QuAC has no test docs.")
def _process_doc(self, doc): def _process_doc(self, doc):
doc["title"] = doc['title'] + ' - ' + doc['section_title'] doc["title"] = doc["title"] + " - " + doc["section_title"]
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: ' return (
"TITLE: "
def doc_to_target(self, doc): + doc["title"]
return doc['answer'] + "\n"
+ "PARAGRAPH: "
def construct_requests(self, doc, ctx): + doc["paragraph"]
""" Uses RequestFactory to construct Requests and returns an iterable of + "\n\n"
Requests which will be sent to the LM. + "Q: "
+ doc["question"]
:param doc: + "\n\n"
The document as returned from training_docs, validation_docs, or test_docs. + "A: "
:param ctx: str )
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question def should_decontaminate(self):
part of the document for `doc`. return True
"""
# TODO: implement evaluation. def doc_to_decontamination_query(self, doc):
raise NotImplementedError('Evaluation not implemented') return doc["paragraph"]
def process_results(self, doc, results): def doc_to_target(self, doc):
"""Take a single document and the LM results and evaluates, returning a return doc["answer"]
dict where keys are the names of submetrics and values are the values of
the metric for that one document def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
:param doc: Requests which will be sent to the LM.
The document as returned from training_docs, validation_docs, or test_docs.
:param results: :param doc:
The results of the requests created in construct_requests. The document as returned from training_docs, validation_docs, or test_docs.
""" :param ctx: str
# TODO: implement evaluation. The context string, generated by fewshot_context. This includes the natural
raise NotImplementedError('Evaluation not implemented') language description, as well as the few shot examples, and the question
part of the document for `doc`.
def aggregation(self): """
""" # TODO: implement evaluation.
:returns: {str: [float] -> float} raise NotImplementedError("Evaluation not implemented")
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics def process_results(self, doc, results):
""" """Take a single document and the LM results and evaluates, returning a
# TODO: implement evaluation. dict where keys are the names of submetrics and values are the values of
raise NotImplementedError('Evaluation not implemented') the metric for that one document
def higher_is_better(self): :param doc:
""" The document as returned from training_docs, validation_docs, or test_docs.
:returns: {str: bool} :param results:
A dictionary where keys are the names of submetrics and values are The results of the requests created in construct_requests.
whether a higher value of the submetric is better """
""" # TODO: implement evaluation.
# TODO: implement evaluation. raise NotImplementedError("Evaluation not implemented")
raise NotImplementedError('Evaluation not implemented')
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError("Evaluation not implemented")
...@@ -20,7 +20,7 @@ _CITATION = """ ...@@ -20,7 +20,7 @@ _CITATION = """
@article{lai2017large, @article{lai2017large,
title={RACE: Large-scale ReAding Comprehension Dataset From Examinations}, title={RACE: Large-scale ReAding Comprehension Dataset From Examinations},
author={Lai, Guokun and Xie, Qizhe and Liu, Hanxiao and Yang, Yiming and Hovy, Eduard}, author={Lai, Guokun and Xie, Qizhe and Liu, Hanxiao and Yang, Yiming and Hovy, Eduard},
journal={arXiv preprint arXiv:1704.04683}, journal={arXiv preprint arXiv:1704.04683},
year={2017} year={2017}
} }
""" """
...@@ -40,7 +40,7 @@ class RACE(Task): ...@@ -40,7 +40,7 @@ class RACE(Task):
DATASET_NAME = "high" DATASET_NAME = "high"
cache = {} cache = {}
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3} letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -59,17 +59,27 @@ class RACE(Task): ...@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage. # is shown that one document is made per passage.
r = collections.defaultdict(list) r = collections.defaultdict(list)
for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]: for item in datasets.load_dataset(
r[item['article']].append(item) path=self.DATASET_PATH, name=self.DATASET_NAME
)[set]:
res = list(r.values() >> each(lambda x: { r[item["article"]].append(item)
'article': x[0]['article'],
'problems': x >> each(lambda y: { res = list(
'question': y['question'], r.values()
'answer': y['answer'], >> each(
'options': y['options'], lambda x: {
}) "article": x[0]["article"],
})) "problems": x
>> each(
lambda y: {
"question": y["question"],
"answer": y["answer"],
"options": y["options"],
}
),
}
)
)
self.cache[set] = res self.cache[set] = res
return res return res
...@@ -85,49 +95,56 @@ class RACE(Task): ...@@ -85,49 +95,56 @@ class RACE(Task):
@classmethod @classmethod
def get_answer_option(cls, problem): def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']] answer = cls.letter_to_num[problem["answer"]]
return problem['options'][answer] return problem["options"][answer]
@classmethod @classmethod
def last_problem(cls, doc): def last_problem(cls, doc):
return doc['problems'][-1] return doc["problems"][-1]
def doc_to_text(self, doc): def doc_to_text(self, doc):
text = 'Article: ' + doc['article'] + '\n\n' text = "Article: " + doc["article"] + "\n\n"
for problem in doc['problems'][:-1]: for problem in doc["problems"][:-1]:
if problem['question'][-6:] == ' _ .': if problem["question"][-6:] == " _ .":
text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' text += (
problem["question"][-5:] + self.get_answer_option(problem) + "\n"
)
else: else:
question = 'Question: ' + problem['question'] + '\n' question = "Question: " + problem["question"] + "\n"
answer = 'Answer: ' + self.get_answer_option(problem) + '\n' answer = "Answer: " + self.get_answer_option(problem) + "\n"
text += question + answer text += question + answer
text += self.last_problem(doc)['question'] text += self.last_problem(doc)["question"]
return text return text
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["article"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + self.get_answer_option(self.last_problem(doc)) return " " + self.get_answer_option(self.last_problem(doc))
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
problem = self.last_problem(doc) problem = self.last_problem(doc)
ll_choices = [ ll_choices = [
rf.loglikelihood(ctx, " " + problem['options'][i])[0] rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
for i in range(4)
] ]
return ll_choices return ll_choices
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -135,28 +152,22 @@ class RACE(Task): ...@@ -135,28 +152,22 @@ class RACE(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
gold = self.letter_to_num[self.last_problem(doc)['answer']] gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": int(pred == gold)}
"acc": int(pred == gold)
}
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
:returns: {str: bool} :returns: {str: bool}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
...@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask): ...@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask):
def _process_doc(self, doc): def _process_doc(self, doc):
return { return {
'source': doc['source'], "source": doc["source"],
'query': doc['stem'].split(' ')[:2], "query": doc["stem"].split(" ")[:2],
'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]], "choices": [
'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()), "{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]
],
"gold": ["a", "b", "c", "d", "e"].index(doc["solution"].strip()),
} }
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query']) return "{} is to {} as".format(*doc["query"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + "\n" + " ".join(doc["query"])
...@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask): ...@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask):
doc["distractor3"], doc["distractor3"],
doc["correct_answer"], doc["correct_answer"],
] ]
src = doc['support'] src = doc["support"]
out_doc = { out_doc = {
"source": src, "source": src,
"query": doc['question'], "query": doc["question"],
"choices": choices, "choices": choices,
"gold": 3, "gold": 3,
} }
...@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask): ...@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip() return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
""" """
Know What You Don’t Know: Unanswerable Questions for SQuAD Know What You Don’t Know: Unanswerable Questions for SQuAD
https://arxiv.org/pdf/1806.03822.pdf https://arxiv.org/pdf/1806.03822.pdf
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset,
consisting of questions posed by crowdworkers on a set of Wikipedia articles, consisting of questions posed by crowdworkers on a set of Wikipedia articles,
where the answer to every question is a segment of text, or span, from the where the answer to every question is a segment of text, or span, from the
corresponding reading passage, or the question might be unanswerable. corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable
questions written adversarially by crowdworkers to look similar to answerable ones. questions written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but To do well on SQuAD2.0, systems must not only answer questions when possible, but
also determine when no answer is supported by the paragraph and abstain from answering. also determine when no answer is supported by the paragraph and abstain from answering.
Homepage: https://rajpurkar.github.io/SQuAD-explorer/ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
""" """
import datasets import datasets
from math import exp from math import exp
from lm_eval.base import rf, Task from lm_eval.base import rf, Task
from functools import partial from functools import partial
from packaging import version from packaging import version
_CITATION = """ _CITATION = """
@misc{rajpurkar2018know, @misc{rajpurkar2018know,
title={Know What You Don't Know: Unanswerable Questions for SQuAD}, title={Know What You Don't Know: Unanswerable Questions for SQuAD},
author={Pranav Rajpurkar and Robin Jia and Percy Liang}, author={Pranav Rajpurkar and Robin Jia and Percy Liang},
year={2018}, year={2018},
eprint={1806.03822}, eprint={1806.03822},
archivePrefix={arXiv}, archivePrefix={arXiv},
primaryClass={cs.CL} primaryClass={cs.CL}
} }
""" """
def _squad_metric(predictions, references): def _squad_metric(predictions, references):
squad_metric = datasets.load_metric("squad_v2") squad_metric = datasets.load_metric("squad_v2")
return squad_metric.compute(predictions=predictions, references=references) return squad_metric.compute(predictions=predictions, references=references)
def _squad_agg(key, items): def _squad_agg(key, items):
predictions, references = zip(*items) predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key] return _squad_metric(predictions=predictions, references=references)[key]
class SQuAD2(Task): class SQuAD2(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = "squad_v2" DATASET_PATH = "squad_v2"
DATASET_NAME = None DATASET_NAME = None
# HF changed squad on us so we have to make sure we aren't running the old one # HF changed squad on us so we have to make sure we aren't running the old one
assert version.parse(datasets.__version__) >= version.parse("1.11.0"), "datasets v1.11.0 or later required for SQuAD" assert version.parse(datasets.__version__) >= version.parse(
"1.11.0"
def has_training_docs(self): ), "datasets v1.11.0 or later required for SQuAD"
return True
def has_training_docs(self):
def has_validation_docs(self): return True
return True
def has_validation_docs(self):
def has_test_docs(self): return True
return False
def has_test_docs(self):
def training_docs(self): return False
return self.dataset["train"]
def training_docs(self):
def validation_docs(self): return self.dataset["train"]
return self.dataset["validation"]
def validation_docs(self):
def doc_to_text(self, doc): return self.dataset["validation"]
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
def doc_to_text(self, doc):
def doc_to_target(self, doc): return (
answer_list = doc['answers']['text'] "Title: "
if len(answer_list) > 0: + doc["title"]
answer = answer_list[0] + "\n\n"
else: + "Background: "
answer = 'unanswerable' + doc["context"]
return " " + answer + "\n\n"
+ "Question: "
def construct_requests(self, doc, ctx): + doc["question"]
""" Uses RequestFactory to construct Requests and returns an iterable of + "\n\n"
Requests which will be sent to the LM. + "Answer:"
)
:param doc:
The document as returned from training_docs, validation_docs, or test_docs. def should_decontaminate(self):
:param ctx: str return True
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question def doc_to_decontamination_query(self, doc):
part of the document for `doc`. return doc["context"]
"""
continuation = rf.greedy_until(ctx, ['\n']) def doc_to_target(self, doc):
is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable") answer_list = doc["answers"]["text"]
return continuation, is_unanswerable if len(answer_list) > 0:
answer = answer_list[0]
def process_results(self, doc, results): else:
"""Take a single document and the LM results and evaluates, returning a answer = "unanswerable"
dict where keys are the names of submetrics and values are the values of return " " + answer
the metric for that one document
def construct_requests(self, doc, ctx):
:param doc: """Uses RequestFactory to construct Requests and returns an iterable of
The document as returned from training_docs, validation_docs, or test_docs. Requests which will be sent to the LM.
:param results:
The results of the requests created in construct_requests. :param doc:
""" The document as returned from training_docs, validation_docs, or test_docs.
continuation, (logprob_unanswerable, _) = results :param ctx: str
The context string, generated by fewshot_context. This includes the natural
no_answer_probability = exp(logprob_unanswerable) language description, as well as the few shot examples, and the question
part of the document for `doc`.
predictions = { """
'id': doc['id'], continuation = rf.greedy_until(ctx, ["\n"])
'prediction_text': continuation, is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
'no_answer_probability': no_answer_probability, return continuation, is_unanswerable
}
def process_results(self, doc, results):
references = { """Take a single document and the LM results and evaluates, returning a
'id': doc['id'], dict where keys are the names of submetrics and values are the values of
'answers': doc['answers'], the metric for that one document
}
:param doc:
return { The document as returned from training_docs, validation_docs, or test_docs.
'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer) :param results:
'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer The results of the requests created in construct_requests.
'HasAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer) """
'HasAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer continuation, (logprob_unanswerable, _) = results
'NoAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer no_answer_probability = exp(logprob_unanswerable)
'best_exact': (predictions, references), # Best exact match (with varying threshold)
'best_f1': (predictions, references), # Best F1 (with varying threshold) predictions = {
} "id": doc["id"],
"prediction_text": continuation,
def aggregation(self): "no_answer_probability": no_answer_probability,
""" }
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are references = {
functions that aggregate a list of metrics "id": doc["id"],
""" "answers": doc["answers"],
return { }
'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer return {
'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer) "exact": (
'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer predictions,
'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer) references,
'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer ), # Exact match (the normalized answer exactly match the gold answer)
'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold) "f1": (
'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold) predictions,
} references,
), # The F-score of predicted tokens versus the gold answer
def higher_is_better(self): "HasAns_exact": (
""" predictions,
:returns: {str: bool} references,
A dictionary where keys are the names of submetrics and values are ), # Exact match (the normalized answer exactly match the gold answer)
whether a higher value of the submetric is better "HasAns_f1": (
""" predictions,
return { references,
'exact': True, # Exact match (the normalized answer exactly match the gold answer) ), # The F-score of predicted tokens versus the gold answer
'f1': True, # The F-score of predicted tokens versus the gold answer "NoAns_exact": (
'HasAns_exact': True, # Exact match (the normalized answer exactly match the gold answer) predictions,
'HasAns_f1': True, # The F-score of predicted tokens versus the gold answer references,
'NoAns_exact': True, # Exact match (the normalized answer exactly match the gold answer) ), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': True, # The F-score of predicted tokens versus the gold answer "NoAns_f1": (
'best_exact': True, # Best exact match (with varying threshold) predictions,
'best_f1': True, # Best F1 (with varying threshold) references,
} ), # The F-score of predicted tokens versus the gold answer
"best_exact": (
predictions,
references,
), # Best exact match (with varying threshold)
"best_f1": (predictions, references), # Best F1 (with varying threshold)
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"exact": partial(
_squad_agg, "exact"
), # Exact match (the normalized answer exactly match the gold answer)
"f1": partial(
_squad_agg, "f1"
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": partial(
_squad_agg, "HasAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": partial(
_squad_agg, "HasAns_f1"
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": partial(
_squad_agg, "NoAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": partial(
_squad_agg, "NoAns_f1"
), # The F-score of predicted tokens versus the gold answer
"best_exact": partial(
_squad_agg, "best_exact"
), # Best exact match (with varying threshold)
"best_f1": partial(
_squad_agg, "best_f1"
), # Best F1 (with varying threshold)
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"exact": True, # Exact match (the normalized answer exactly match the gold answer)
"f1": True, # The F-score of predicted tokens versus the gold answer
"HasAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": True, # The F-score of predicted tokens versus the gold answer
"NoAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": True, # The F-score of predicted tokens versus the gold answer
"best_exact": True, # Best exact match (with varying threshold)
"best_f1": True, # Best F1 (with varying threshold)
}
...@@ -65,12 +65,27 @@ class StoryCloze(Task): ...@@ -65,12 +65,27 @@ class StoryCloze(Task):
return self.dataset["test"] return self.dataset["test"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ' '.join([ return " ".join(
doc["input_sentence_1"], [
doc["input_sentence_2"], doc["input_sentence_1"],
doc["input_sentence_3"], doc["input_sentence_2"],
doc["input_sentence_4"], doc["input_sentence_3"],
]) doc["input_sentence_4"],
]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return " ".join(
[
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
]
)
def doc_to_target(self, doc): def doc_to_target(self, doc):
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]] clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
...@@ -78,7 +93,7 @@ class StoryCloze(Task): ...@@ -78,7 +93,7 @@ class StoryCloze(Task):
return " " + clozes[doc["answer_right_ending"] - 1] return " " + clozes[doc["answer_right_ending"] - 1]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -89,10 +104,7 @@ class StoryCloze(Task): ...@@ -89,10 +104,7 @@ class StoryCloze(Task):
part of the document for `doc`. part of the document for `doc`.
""" """
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]] clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
lls = [ lls = [rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in clozes]
rf.loglikelihood(ctx, " {}".format(choice))[0]
for choice in clozes
]
return lls return lls
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -106,10 +118,8 @@ class StoryCloze(Task): ...@@ -106,10 +118,8 @@ class StoryCloze(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
gold = doc["answer_right_ending"] - 1 gold = doc["answer_right_ending"] - 1
acc = 1. if np.argmax(results) == gold else 0. acc = 1.0 if np.argmax(results) == gold else 0.0
return { return {"acc": acc}
"acc": acc
}
def aggregation(self): def aggregation(self):
""" """
...@@ -117,9 +127,7 @@ class StoryCloze(Task): ...@@ -117,9 +127,7 @@ class StoryCloze(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -127,9 +135,7 @@ class StoryCloze(Task): ...@@ -127,9 +135,7 @@ class StoryCloze(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
class StoryCloze2016(StoryCloze): class StoryCloze2016(StoryCloze):
......
...@@ -56,14 +56,20 @@ class BoolQ(Task): ...@@ -56,14 +56,20 @@ class BoolQ(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:" return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["passage"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + yesno(doc['label']) return " " + yesno(doc["label"])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes') ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, ' no') ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no return ll_yes, ll_no
...@@ -71,21 +77,15 @@ class BoolQ(Task): ...@@ -71,21 +77,15 @@ class BoolQ(Task):
ll_yes, ll_no = results ll_yes, ll_no = results
gold = doc["label"] gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0. acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return {"acc": acc}
return {
"acc": acc
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class CommitmentBank(Task): class CommitmentBank(Task):
...@@ -123,27 +123,21 @@ class CommitmentBank(Task): ...@@ -123,27 +123,21 @@ class CommitmentBank(Task):
return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]]) return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, ' True') ll_true, _ = rf.loglikelihood(ctx, " True")
ll_false, _ = rf.loglikelihood(ctx, ' False') ll_false, _ = rf.loglikelihood(ctx, " False")
ll_neither, _ = rf.loglikelihood(ctx, ' Neither') ll_neither, _ = rf.loglikelihood(ctx, " Neither")
return ll_true, ll_false, ll_neither return ll_true, ll_false, ll_neither
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
acc = 1. if pred == gold else 0. acc = 1.0 if pred == gold else 0.0
return {"acc": acc, "f1": (pred, gold)}
return {
"acc": acc,
"f1": (pred, gold)
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "f1": True}
"acc": True,
"f1": True
}
@classmethod @classmethod
def cb_multi_fi(cls, items): def cb_multi_fi(cls, items):
...@@ -155,7 +149,7 @@ class CommitmentBank(Task): ...@@ -155,7 +149,7 @@ class CommitmentBank(Task):
f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
avg_f1 = mean([f11, f12, f13]) avg_f1 = mean([f11, f12, f13])
return avg_f1 return avg_f1
def aggregation(self): def aggregation(self):
return { return {
"acc": mean, "acc": mean,
...@@ -201,7 +195,7 @@ class Copa(Task): ...@@ -201,7 +195,7 @@ class Copa(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
choice1 = " " + self.convert_choice(doc["choice1"]) choice1 = " " + self.convert_choice(doc["choice1"])
choice2 = " " + self.convert_choice(doc["choice2"]) choice2 = " " + self.convert_choice(doc["choice2"])
ll_choice1, _ = rf.loglikelihood(ctx, choice1) ll_choice1, _ = rf.loglikelihood(ctx, choice1)
ll_choice2, _ = rf.loglikelihood(ctx, choice2) ll_choice2, _ = rf.loglikelihood(ctx, choice2)
...@@ -210,21 +204,15 @@ class Copa(Task): ...@@ -210,21 +204,15 @@ class Copa(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
acc = 1. if pred == gold else 0. acc = 1.0 if pred == gold else 0.0
return {"acc": acc}
return {
"acc": acc
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
@staticmethod @staticmethod
def convert_choice(choice): def convert_choice(choice):
...@@ -267,28 +255,22 @@ class MultiRC(Task): ...@@ -267,28 +255,22 @@ class MultiRC(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
true_choice = self.format_answer(answer=doc["answer"], label=True) true_choice = self.format_answer(answer=doc["answer"], label=True)
false_choice = self.format_answer(answer=doc["answer"], label=False) false_choice = self.format_answer(answer=doc["answer"], label=False)
ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}') ll_true_choice, _ = rf.loglikelihood(ctx, f" {true_choice}")
ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}') ll_false_choice, _ = rf.loglikelihood(ctx, f" {false_choice}")
return ll_true_choice, ll_false_choice return ll_true_choice, ll_false_choice
def process_results(self, doc, results): def process_results(self, doc, results):
ll_true_choice, ll_false_choice = results ll_true_choice, ll_false_choice = results
pred = ll_true_choice > ll_false_choice pred = ll_true_choice > ll_false_choice
return { return {"acc": (pred, doc)}
"acc": (pred, doc)
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": acc_all}
"acc": acc_all
}
class ReCoRD(Task): class ReCoRD(Task):
...@@ -337,7 +319,7 @@ class ReCoRD(Task): ...@@ -337,7 +319,7 @@ class ReCoRD(Task):
@classmethod @classmethod
def format_answer(cls, query, entity): def format_answer(cls, query, entity):
return f' - {query}'.replace("@placeholder", entity) return f" - {query}".replace("@placeholder", entity)
def doc_to_target(self, doc): def doc_to_target(self, doc):
# We only output the first correct entity in a doc # We only output the first correct entity in a doc
...@@ -359,8 +341,12 @@ class ReCoRD(Task): ...@@ -359,8 +341,12 @@ class ReCoRD(Task):
prediction = doc["entities"][max_idx] prediction = doc["entities"][max_idx]
gold_label_set = doc["answers"] gold_label_set = doc["answers"]
f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set) f1 = metric_max_over_ground_truths(
em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set) squad_metrics.compute_f1, prediction, gold_label_set
)
em = metric_max_over_ground_truths(
squad_metrics.compute_exact, prediction, gold_label_set
)
return { return {
"f1": f1, "f1": f1,
...@@ -403,19 +389,21 @@ class WordsInContext(Task): ...@@ -403,19 +389,21 @@ class WordsInContext(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \ return (
" two sentences above?\nAnswer:".format( "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the"
doc["sentence1"], " two sentences above?\nAnswer:".format(
doc["sentence2"], doc["sentence1"],
doc["sentence1"][doc["start1"]:doc["end1"]], doc["sentence2"],
) doc["sentence1"][doc["start1"] : doc["end1"]],
)
)
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format({0: "no", 1: "yes"}[doc["label"]]) return " {}".format({0: "no", 1: "yes"}[doc["label"]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes') ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, ' no') ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no return ll_yes, ll_no
...@@ -423,21 +411,15 @@ class WordsInContext(Task): ...@@ -423,21 +411,15 @@ class WordsInContext(Task):
ll_yes, ll_no = results ll_yes, ll_no = results
gold = doc["label"] gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0. acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return { return {"acc": acc}
"acc": acc
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class SGWinogradSchemaChallenge(Task): class SGWinogradSchemaChallenge(Task):
...@@ -461,9 +443,7 @@ class SGWinogradSchemaChallenge(Task): ...@@ -461,9 +443,7 @@ class SGWinogradSchemaChallenge(Task):
if self._training_docs is None: if self._training_docs is None:
# GPT-3 Paper's format only uses positive examples for fewshot "training" # GPT-3 Paper's format only uses positive examples for fewshot "training"
self._training_docs = [ self._training_docs = [
doc for doc in doc for doc in self.dataset["train"] if doc["label"]
self.dataset["train"]
if doc["label"]
] ]
return self._training_docs return self._training_docs
...@@ -473,25 +453,25 @@ class SGWinogradSchemaChallenge(Task): ...@@ -473,25 +453,25 @@ class SGWinogradSchemaChallenge(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
raw_passage = doc["text"] raw_passage = doc["text"]
# NOTE: HuggingFace span indices are word-based not character-based. # NOTE: HuggingFace span indices are word-based not character-based.
pre = " ".join(raw_passage.split()[:doc["span2_index"]]) pre = " ".join(raw_passage.split()[: doc["span2_index"]])
post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:] post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :]
passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post) passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post)
noun = doc["span1_text"] noun = doc["span1_text"]
pronoun = doc["span2_text"] pronoun = doc["span2_text"]
text = ( text = (
f"Passage: {passage}\n" f"Passage: {passage}\n"
+ f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n" + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
+ "Answer:" + "Answer:"
) )
return text return text
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + yesno(doc['label']) return " " + yesno(doc["label"])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes') ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, ' no') ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no return ll_yes, ll_no
...@@ -499,18 +479,12 @@ class SGWinogradSchemaChallenge(Task): ...@@ -499,18 +479,12 @@ class SGWinogradSchemaChallenge(Task):
ll_yes, ll_no = results ll_yes, ll_no = results
gold = doc["label"] gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0. acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return { return {"acc": acc}
"acc": acc
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
...@@ -41,44 +41,57 @@ def create_tasks_from_benchmarks(benchmark_dict): ...@@ -41,44 +41,57 @@ def create_tasks_from_benchmarks(benchmark_dict):
:return: {task_name: task} :return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task} e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
""" """
def version_of(dataset, language_pair): def version_of(dataset, language_pair):
if language_pair[-2:] in ["zh", "ja"]: if language_pair[-2:] in ["zh", "ja"]:
return 1 # changed to use jieba/nagisa return 1 # changed to use jieba/nagisa
return 0 return 0
return { return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair)) f"{dataset}-{language_pair}": create_translation_task(
dataset, language_pair, version_of(dataset, language_pair)
)
for dataset, language_pairs in benchmark_dict.items() for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs for language_pair in language_pairs
} }
######################################## ########################################
# Language Specifics # Language Specifics
######################################## ########################################
def zh_split(zh_text: List[str]) -> List[str]: def zh_split(zh_text: List[str]) -> List[str]:
"""Chinese splitting""" """Chinese splitting"""
import jieba import jieba
return [" ".join(jieba.cut(txt.strip())) for txt in zh_text] return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]
def ja_split(ja_text: List[str]) -> List[str]: def ja_split(ja_text: List[str]) -> List[str]:
"""Japanese splitting""" """Japanese splitting"""
import nagisa import nagisa
return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text] return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split} NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
######################################## ########################################
# Tasks # Tasks
######################################## ########################################
def create_translation_task(dataset, language_pair, version=0): def create_translation_task(dataset, language_pair, version=0):
class TranslationTask(GeneralTranslationTask): class TranslationTask(GeneralTranslationTask):
VERSION = version VERSION = version
def __init__(self): def __init__(self):
super().__init__(dataset, language_pair) super().__init__(dataset, language_pair)
return TranslationTask return TranslationTask
class GeneralTranslationTask(Task): class GeneralTranslationTask(Task):
VERSION = 0 VERSION = 0
...@@ -92,8 +105,9 @@ class GeneralTranslationTask(Task): ...@@ -92,8 +105,9 @@ class GeneralTranslationTask(Task):
def download(self, data_dir=None, cache_dir=None, download_mode=None): def download(self, data_dir=None, cache_dir=None, download_mode=None):
# This caches in the users home dir automatically # This caches in the users home dir automatically
self.src_file, self.ref_file = \ self.src_file, self.ref_file = sacrebleu.download_test_set(
sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair) self.sacrebleu_dataset, self.sacrebleu_language_pair
)
self.src_data, self.ref_data = [ self.src_data, self.ref_data = [
[line.rstrip() for line in sacrebleu.smart_open(file)] [line.rstrip() for line in sacrebleu.smart_open(file)]
for file in (self.src_file, self.ref_file) for file in (self.src_file, self.ref_file)
...@@ -117,10 +131,9 @@ class GeneralTranslationTask(Task): ...@@ -117,10 +131,9 @@ class GeneralTranslationTask(Task):
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
""" """
return [{ return [
"src": src, {"src": src, "ref": ref} for src, ref in zip(self.src_data, self.ref_data)
"ref": ref ]
} for src, ref in zip(self.src_data, self.ref_data)]
def doc_to_text(self, doc): def doc_to_text(self, doc):
language_codes = self.sacrebleu_language_pair.split("-") language_codes = self.sacrebleu_language_pair.split("-")
...@@ -128,12 +141,18 @@ class GeneralTranslationTask(Task): ...@@ -128,12 +141,18 @@ class GeneralTranslationTask(Task):
tar_lang = code_to_language(language_codes[1]) tar_lang = code_to_language(language_codes[1])
return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:" return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["src"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# This shows a single target, though there may be multiple targets in a lang test # This shows a single target, though there may be multiple targets in a lang test
return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0] return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
......
...@@ -43,10 +43,10 @@ class TriviaQA(Task): ...@@ -43,10 +43,10 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return self.dataset['train'] return self.dataset["train"]
def validation_docs(self): def validation_docs(self):
return self.dataset['validation'] return self.dataset["validation"]
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
...@@ -54,8 +54,14 @@ class TriviaQA(Task): ...@@ -54,8 +54,14 @@ class TriviaQA(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Question: {doc['question']}\nAnswer:" return f"Question: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['answer']['value'] return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases): def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list # Optimization: Remove any alias that has a strict prefix elsewhere in the list
...@@ -69,15 +75,13 @@ class TriviaQA(Task): ...@@ -69,15 +75,13 @@ class TriviaQA(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] ret = []
for alias in self._remove_prefixes(doc['answer']['aliases']): for alias in self._remove_prefixes(doc["answer"]["aliases"]):
_, is_prediction = rf.loglikelihood(ctx, " " + alias) _, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction) ret.append(is_prediction)
return ret return ret
def process_results(self, doc, results): def process_results(self, doc, results):
return { return {"acc": float(any(results))}
"acc": float(any(results))
}
def aggregation(self): def aggregation(self):
return { return {
...@@ -85,6 +89,4 @@ class TriviaQA(Task): ...@@ -85,6 +89,4 @@ class TriviaQA(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
...@@ -80,22 +80,29 @@ class TruthfulQAMultipleChoice(Task): ...@@ -80,22 +80,29 @@ class TruthfulQAMultipleChoice(Task):
raise NotImplementedError() raise NotImplementedError()
def doc_to_text(self, doc): def doc_to_text(self, doc):
return QA_PROMPT + "\n\nQ: " + doc['question'] + "\nA:" return QA_PROMPT + "\n\nQ: " + doc["question"] + "\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " return " "
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context( return super().fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -105,11 +112,15 @@ class TruthfulQAMultipleChoice(Task): ...@@ -105,11 +112,15 @@ class TruthfulQAMultipleChoice(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
def get_lls(targets): def get_lls(targets):
return [rf.loglikelihood(ctx, " " + t)[0] for t in targets] return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]
# MC1 and MC2 targets are not always the same set of strings so we collect # MC1 and MC2 targets are not always the same set of strings so we collect
# likelihoods separately for simpler processing. # likelihoods separately for simpler processing.
return get_lls(doc['mc1_targets']["choices"]) + get_lls(doc['mc2_targets']["choices"]) return get_lls(doc["mc1_targets"]["choices"]) + get_lls(
doc["mc2_targets"]["choices"]
)
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -121,37 +132,29 @@ class TruthfulQAMultipleChoice(Task): ...@@ -121,37 +132,29 @@ class TruthfulQAMultipleChoice(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
def mc1(lls): def mc1(lls):
# The gold answers in `mc1_targets` are always first (index = `0`). # The gold answers in `mc1_targets` are always first (index = `0`).
return np.argmax(lls) == 0 return np.argmax(lls) == 0
def mc2(lls): def mc2(lls):
# Split on the first `0` as everything before it is true (`1`). # Split on the first `0` as everything before it is true (`1`).
split_idx = list(doc['mc2_targets']["labels"]).index(0) split_idx = list(doc["mc2_targets"]["labels"]).index(0)
# Compute the normalized probability mass for the correct answer. # Compute the normalized probability mass for the correct answer.
ll_true, ll_false = lls[:split_idx], lls[split_idx:] ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false)) p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false)) p_true = p_true / (sum(p_true) + sum(p_false))
return sum(p_true) return sum(p_true)
split_idx = len(doc['mc1_targets']["choices"]) split_idx = len(doc["mc1_targets"]["choices"])
mc1_lls, mc2_lls = results[:split_idx], results[split_idx:] mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]
return { return {"mc1": mc1(mc1_lls), "mc2": mc2(mc2_lls)}
"mc1": mc1(mc1_lls),
"mc2": mc2(mc2_lls)
}
def aggregation(self): def aggregation(self):
return { return {"mc1": mean, "mc2": mean}
"mc1": mean,
"mc2": mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"mc1": True, "mc2": True}
"mc1": True,
"mc2": True
}
class TruthfulQAGeneration(Task): class TruthfulQAGeneration(Task):
...@@ -181,44 +184,45 @@ class TruthfulQAGeneration(Task): ...@@ -181,44 +184,45 @@ class TruthfulQAGeneration(Task):
answer = answer.strip() answer = answer.strip()
if len(answer): if len(answer):
# Add a period after all answers. # Add a period after all answers.
if answer[-1] != '.': if answer[-1] != ".":
formatted_answers.append(answer + '.') formatted_answers.append(answer + ".")
else: else:
formatted_answers.append(answer) formatted_answers.append(answer)
return formatted_answers return formatted_answers
def validation_docs(self): def validation_docs(self):
for doc in self.dataset["validation"]: for doc in self.dataset["validation"]:
incorrect_answers = self._format_answers(doc['incorrect_answers']) incorrect_answers = self._format_answers(doc["incorrect_answers"])
correct_answers = self._format_answers(doc['correct_answers']) correct_answers = self._format_answers(doc["correct_answers"])
if "I have no comment." not in correct_answers: if "I have no comment." not in correct_answers:
correct_answers.append("I have no comment.") correct_answers.append("I have no comment.")
yield { yield {
'question': doc['question'].strip(), "question": doc["question"].strip(),
'correct_answers': correct_answers, "correct_answers": correct_answers,
'incorrect_answers': incorrect_answers "incorrect_answers": incorrect_answers,
} }
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
def doc_to_text(self, doc): def doc_to_text(self, doc):
return QA_PROMPT + "\n\nQ: " + doc['question'] return QA_PROMPT + "\n\nQ: " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " return " "
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context( return super().fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -229,7 +233,7 @@ class TruthfulQAGeneration(Task): ...@@ -229,7 +233,7 @@ class TruthfulQAGeneration(Task):
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: Find a way to cap the number of generated tokens to `50` as in the official implementation. # TODO: Find a way to cap the number of generated tokens to `50` as in the official implementation.
completion = rf.greedy_until(ctx, ['.']) completion = rf.greedy_until(ctx, ["."])
return completion return completion
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -243,18 +247,18 @@ class TruthfulQAGeneration(Task): ...@@ -243,18 +247,18 @@ class TruthfulQAGeneration(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
completion = results[0].strip() completion = results[0].strip()
true_refs, false_refs = doc['correct_answers'], doc['incorrect_answers'] true_refs, false_refs = doc["correct_answers"], doc["incorrect_answers"]
all_refs = true_refs + false_refs all_refs = true_refs + false_refs
# Process the sentence-level BLEURT, BLEU, and ROUGE for similarity measures. # Process the sentence-level BLEURT, BLEU, and ROUGE for similarity measures.
# BLEURT # BLEURT
bleurt_scores_true = self.bleurt.compute( bleurt_scores_true = self.bleurt.compute(
predictions=[completion] * len(true_refs), predictions=[completion] * len(true_refs), references=true_refs
references=true_refs)['scores'] )["scores"]
bleurt_scores_false = self.bleurt.compute( bleurt_scores_false = self.bleurt.compute(
predictions=[completion] * len(false_refs), predictions=[completion] * len(false_refs), references=false_refs
references=false_refs)['scores'] )["scores"]
bleurt_correct = max(bleurt_scores_true) bleurt_correct = max(bleurt_scores_true)
bleurt_incorrect = max(bleurt_scores_false) bleurt_incorrect = max(bleurt_scores_false)
bleurt_max = bleurt_correct bleurt_max = bleurt_correct
...@@ -263,8 +267,8 @@ class TruthfulQAGeneration(Task): ...@@ -263,8 +267,8 @@ class TruthfulQAGeneration(Task):
# BLEU # BLEU
bleu_scores = [self.bleu([[ref]], [completion]) for ref in all_refs] bleu_scores = [self.bleu([[ref]], [completion]) for ref in all_refs]
bleu_correct = np.nanmax(bleu_scores[:len(true_refs)]) bleu_correct = np.nanmax(bleu_scores[: len(true_refs)])
bleu_incorrect = np.nanmax(bleu_scores[len(true_refs):]) bleu_incorrect = np.nanmax(bleu_scores[len(true_refs) :])
bleu_max = bleu_correct bleu_max = bleu_correct
bleu_diff = bleu_correct - bleu_incorrect bleu_diff = bleu_correct - bleu_incorrect
bleu_acc = int(bleu_correct > bleu_incorrect) bleu_acc = int(bleu_correct > bleu_incorrect)
...@@ -272,23 +276,23 @@ class TruthfulQAGeneration(Task): ...@@ -272,23 +276,23 @@ class TruthfulQAGeneration(Task):
# ROUGE-N # ROUGE-N
rouge_scores = [self.rouge([ref], [completion]) for ref in all_refs] rouge_scores = [self.rouge([ref], [completion]) for ref in all_refs]
# ROUGE-1 # ROUGE-1
rouge1_scores = [score['rouge1'] for score in rouge_scores] rouge1_scores = [score["rouge1"] for score in rouge_scores]
rouge1_correct = np.nanmax(rouge1_scores[:len(true_refs)]) rouge1_correct = np.nanmax(rouge1_scores[: len(true_refs)])
rouge1_incorrect = np.nanmax(rouge1_scores[len(true_refs):]) rouge1_incorrect = np.nanmax(rouge1_scores[len(true_refs) :])
rouge1_max = rouge1_correct rouge1_max = rouge1_correct
rouge1_diff = rouge1_correct - rouge1_incorrect rouge1_diff = rouge1_correct - rouge1_incorrect
rouge1_acc = int(rouge1_correct > rouge1_incorrect) rouge1_acc = int(rouge1_correct > rouge1_incorrect)
# ROUGE-2 # ROUGE-2
rouge2_scores = [score['rouge2'] for score in rouge_scores] rouge2_scores = [score["rouge2"] for score in rouge_scores]
rouge2_correct = np.nanmax(rouge2_scores[:len(true_refs)]) rouge2_correct = np.nanmax(rouge2_scores[: len(true_refs)])
rouge2_incorrect = np.nanmax(rouge2_scores[len(true_refs):]) rouge2_incorrect = np.nanmax(rouge2_scores[len(true_refs) :])
rouge2_max = rouge2_correct rouge2_max = rouge2_correct
rouge2_diff = rouge2_correct - rouge2_incorrect rouge2_diff = rouge2_correct - rouge2_incorrect
rouge2_acc = int(rouge2_correct > rouge2_incorrect) rouge2_acc = int(rouge2_correct > rouge2_incorrect)
# ROUGE-L # ROUGE-L
rougeL_scores = [score['rougeLsum'] for score in rouge_scores] rougeL_scores = [score["rougeLsum"] for score in rouge_scores]
rougeL_correct = np.nanmax(rougeL_scores[:len(true_refs)]) rougeL_correct = np.nanmax(rougeL_scores[: len(true_refs)])
rougeL_incorrect = np.nanmax(rougeL_scores[len(true_refs):]) rougeL_incorrect = np.nanmax(rougeL_scores[len(true_refs) :])
rougeL_max = rougeL_correct rougeL_max = rougeL_correct
rougeL_diff = rougeL_correct - rougeL_incorrect rougeL_diff = rougeL_correct - rougeL_incorrect
rougeL_acc = int(rougeL_correct > rougeL_incorrect) rougeL_acc = int(rougeL_correct > rougeL_incorrect)
...@@ -297,19 +301,15 @@ class TruthfulQAGeneration(Task): ...@@ -297,19 +301,15 @@ class TruthfulQAGeneration(Task):
"bleurt_max": bleurt_max, "bleurt_max": bleurt_max,
"bleurt_acc": bleurt_acc, "bleurt_acc": bleurt_acc,
"bleurt_diff": bleurt_diff, "bleurt_diff": bleurt_diff,
"bleu_max": bleu_max, "bleu_max": bleu_max,
"bleu_acc": bleu_acc, "bleu_acc": bleu_acc,
"bleu_diff": bleu_diff, "bleu_diff": bleu_diff,
"rouge1_max": rouge1_max, "rouge1_max": rouge1_max,
"rouge1_acc": rouge1_acc, "rouge1_acc": rouge1_acc,
"rouge1_diff": rouge1_diff, "rouge1_diff": rouge1_diff,
"rouge2_max": rouge2_max, "rouge2_max": rouge2_max,
"rouge2_acc": rouge2_acc, "rouge2_acc": rouge2_acc,
"rouge2_diff": rouge2_diff, "rouge2_diff": rouge2_diff,
"rougeL_max": rougeL_max, "rougeL_max": rougeL_max,
"rougeL_acc": rougeL_acc, "rougeL_acc": rougeL_acc,
"rougeL_diff": rougeL_diff, "rougeL_diff": rougeL_diff,
...@@ -320,19 +320,15 @@ class TruthfulQAGeneration(Task): ...@@ -320,19 +320,15 @@ class TruthfulQAGeneration(Task):
"bleurt_max": mean, "bleurt_max": mean,
"bleurt_acc": mean, "bleurt_acc": mean,
"bleurt_diff": mean, "bleurt_diff": mean,
"bleu_max": mean, "bleu_max": mean,
"bleu_acc": mean, "bleu_acc": mean,
"bleu_diff": mean, "bleu_diff": mean,
"rouge1_max": mean, "rouge1_max": mean,
"rouge1_acc": mean, "rouge1_acc": mean,
"rouge1_diff": mean, "rouge1_diff": mean,
"rouge2_max": mean, "rouge2_max": mean,
"rouge2_acc": mean, "rouge2_acc": mean,
"rouge2_diff": mean, "rouge2_diff": mean,
"rougeL_max": mean, "rougeL_max": mean,
"rougeL_acc": mean, "rougeL_acc": mean,
"rougeL_diff": mean, "rougeL_diff": mean,
...@@ -343,19 +339,15 @@ class TruthfulQAGeneration(Task): ...@@ -343,19 +339,15 @@ class TruthfulQAGeneration(Task):
"bleurt_max": True, "bleurt_max": True,
"bleurt_acc": True, "bleurt_acc": True,
"bleurt_diff": True, "bleurt_diff": True,
"bleu_max": True, "bleu_max": True,
"bleu_acc": True, "bleu_acc": True,
"bleu_diff": True, "bleu_diff": True,
"rouge1_max": True, "rouge1_max": True,
"rouge1_acc": True, "rouge1_acc": True,
"rouge1_diff": True, "rouge1_diff": True,
"rouge2_max": True, "rouge2_max": True,
"rouge2_acc": True, "rouge2_acc": True,
"rouge2_diff": True, "rouge2_diff": True,
"rougeL_max": True, "rougeL_max": True,
"rougeL_acc": True, "rougeL_acc": True,
"rougeL_diff": True, "rougeL_diff": True,
...@@ -379,7 +371,7 @@ class TruthfulQAGeneration(Task): ...@@ -379,7 +371,7 @@ class TruthfulQAGeneration(Task):
force=False, force=False,
lowercase=False, lowercase=False,
tokenize="intl", tokenize="intl",
use_effective_order=False use_effective_order=False,
).score ).score
return score return score
...@@ -396,9 +388,11 @@ class TruthfulQAGeneration(Task): ...@@ -396,9 +388,11 @@ class TruthfulQAGeneration(Task):
rouge_types = ["rouge1", "rouge2", "rougeLsum"] rouge_types = ["rouge1", "rouge2", "rougeLsum"]
scorer = rouge_scorer.RougeScorer(rouge_types) scorer = rouge_scorer.RougeScorer(rouge_types)
# Add newlines between sentences to correctly compute `rougeLsum`. # Add newlines between sentences to correctly compute `rougeLsum`.
def _prepare_summary(summary): def _prepare_summary(summary):
summary = summary.replace(" . ", ".\n") summary = summary.replace(" . ", ".\n")
return summary return summary
# Accumulate confidence intervals. # Accumulate confidence intervals.
aggregator = scoring.BootstrapAggregator() aggregator = scoring.BootstrapAggregator()
for ref, pred in zip(refs, preds): for ref, pred in zip(refs, preds):
...@@ -406,4 +400,4 @@ class TruthfulQAGeneration(Task): ...@@ -406,4 +400,4 @@ class TruthfulQAGeneration(Task):
pred = _prepare_summary(pred) pred = _prepare_summary(pred)
aggregator.add_scores(scorer.score(ref, pred)) aggregator.add_scores(scorer.score(ref, pred))
result = aggregator.aggregate() result = aggregator.aggregate()
return {type: result[type].mid.fmeasure*100 for type in rouge_types} return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
...@@ -49,6 +49,12 @@ class WordUnscrambleTask(Task): ...@@ -49,6 +49,12 @@ class WordUnscrambleTask(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["context"] return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc["completion"] return doc["completion"]
...@@ -59,19 +65,13 @@ class WordUnscrambleTask(Task): ...@@ -59,19 +65,13 @@ class WordUnscrambleTask(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
pred = results[0] pred = results[0]
gold = doc["completion"] gold = doc["completion"]
return { return {"acc": int(pred == gold)}
"acc": int(pred == gold)
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
class Anagrams1(WordUnscrambleTask): class Anagrams1(WordUnscrambleTask):
......
...@@ -54,14 +54,20 @@ class WebQs(Task): ...@@ -54,14 +54,20 @@ class WebQs(Task):
return self.dataset["test"] return self.dataset["test"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:' return "Question: " + doc["question"] + "\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# this picks one answer to be the "correct" one, despite sometimes # this picks one answer to be the "correct" one, despite sometimes
# multiple correct answers being possible. # multiple correct answers being possible.
# TODO: make sure we're actually handling multi-answer correctly # TODO: make sure we're actually handling multi-answer correctly
return " " + doc['answers'][0] return " " + doc["answers"][0]
def _remove_prefixes(self, aliases): def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list # Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
...@@ -75,15 +81,13 @@ class WebQs(Task): ...@@ -75,15 +81,13 @@ class WebQs(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] ret = []
for alias in self._remove_prefixes(doc['answers']): for alias in self._remove_prefixes(doc["answers"]):
_, is_prediction = rf.loglikelihood(ctx, " " + alias) _, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction) ret.append(is_prediction)
return ret return ret
def process_results(self, doc, results): def process_results(self, doc, results):
return { return {"acc": float(any(results))}
"acc": float(any(results))
}
def aggregation(self): def aggregation(self):
return { return {
...@@ -91,6 +95,4 @@ class WebQs(Task): ...@@ -91,6 +95,4 @@ class WebQs(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Pointer Sentinel Mixture Models Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf https://arxiv.org/pdf/1609.07843.pdf
The WikiText language modeling dataset is a collection of over 100 million tokens The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia. extracted from the set of verified Good and Featured articles on Wikipedia.
NOTE: This `Task` is based on WikiText-2. NOTE: This `Task` is based on WikiText-2.
...@@ -17,7 +17,7 @@ from lm_eval.base import PerplexityTask ...@@ -17,7 +17,7 @@ from lm_eval.base import PerplexityTask
_CITATION = """ _CITATION = """
@misc{merity2016pointer, @misc{merity2016pointer,
title={Pointer Sentinel Mixture Models}, title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher}, author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016}, year={2016},
eprint={1609.07843}, eprint={1609.07843},
...@@ -90,6 +90,9 @@ class WikiText(PerplexityTask): ...@@ -90,6 +90,9 @@ class WikiText(PerplexityTask):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return wikitext_detokenizer(doc) return wikitext_detokenizer(doc)
def should_decontaminate(self):
return True
def count_words(self, doc): def count_words(self, doc):
# count number of words in *original doc before detokenization* # count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
""" """
WinoGrande: An Adversarial Winograd Schema Challenge at Scale WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf https://arxiv.org/pdf/1907.10641.pdf
WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and (Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning. sentence which requires commonsense reasoning.
NOTE: This evaluation of Winogrande uses partial evaluation as described by NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018). Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847 See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
""" """
import numpy as np import numpy as np
from lm_eval.base import rf, Task from lm_eval.base import rf, Task
from lm_eval.metrics import mean from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@article{sakaguchi2019winogrande, @article{sakaguchi2019winogrande,
title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale}, title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin}, author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
journal={arXiv preprint arXiv:1907.10641}, journal={arXiv preprint arXiv:1907.10641},
year={2019} year={2019}
} }
""" """
class Winogrande(Task): class Winogrande(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "winogrande" DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl" DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1} answer_to_num = {"1": 0, "2": 1}
def has_training_docs(self): def has_training_docs(self):
return True return True
def has_validation_docs(self): def has_validation_docs(self):
return True return True
def has_test_docs(self): def has_test_docs(self):
return False return False
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(self.dataset["train"])
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]]) return self.partial_context(doc, doc["option" + doc["answer"]])
@classmethod def should_decontaminate(self):
def partial_context(cls, doc, option): return True
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after. def doc_to_decontamination_query(self, doc):
pronoun_loc = doc["sentence"].index("_") return doc["sentence"]
return doc["sentence"][:pronoun_loc] + option
@classmethod
def doc_to_target(self, doc): def partial_context(cls, doc, option):
return self.partial_target(doc) # Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
@classmethod pronoun_loc = doc["sentence"].index("_")
def partial_target(cls, doc): return doc["sentence"][:pronoun_loc] + option
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1 def doc_to_target(self, doc):
return " " + doc["sentence"][pronoun_loc:].strip() return self.partial_target(doc)
def construct_requests(self, doc, ctx): @classmethod
"""Uses RequestFactory to construct Requests and returns an iterable of def partial_target(cls, doc):
Requests which will be sent to the LM. # The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
:param doc: return " " + doc["sentence"][pronoun_loc:].strip()
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str def construct_requests(self, doc, ctx):
The context string, generated by fewshot_context. This includes the natural """Uses RequestFactory to construct Requests and returns an iterable of
language description, as well as the few shot examples, and the question Requests which will be sent to the LM.
part of the document for `doc`.
""" :param doc:
target = self.partial_target(doc) The document as returned from training_docs, validation_docs, or test_docs.
lls = [] :param ctx: str
for option in [doc["option1"], doc["option2"]]: The context string, generated by fewshot_context. This includes the natural
partial_ctx = self.partial_context(doc, option) language description, as well as the few shot examples, and the question
full_ctx = self.append_context(ctx, partial_ctx) part of the document for `doc`.
lls.append(rf.loglikelihood(full_ctx, target)[0]) """
return lls target = self.partial_target(doc)
lls = []
@classmethod for option in [doc["option1"], doc["option2"]]:
def append_context(cls, ctx, partial_ctx): partial_ctx = self.partial_context(doc, option)
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line. full_ctx = self.append_context(ctx, partial_ctx)
ctx.pop() # Remove the correct context put in by `doc_to_text`. lls.append(rf.loglikelihood(full_ctx, target)[0])
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx return lls
def process_results(self, doc, results): @classmethod
"""Take a single document and the LM results and evaluates, returning a def append_context(cls, ctx, partial_ctx):
dict where keys are the names of submetrics and values are the values of ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
the metric for that one document ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
:param doc:
The document as returned from training_docs, validation_docs, or test_docs. def process_results(self, doc, results):
:param results: """Take a single document and the LM results and evaluates, returning a
The results of the requests created in construct_requests. dict where keys are the names of submetrics and values are the values of
""" the metric for that one document
return {
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]] :param doc:
} The document as returned from training_docs, validation_docs, or test_docs.
:param results:
def aggregation(self): The results of the requests created in construct_requests.
""" """
:returns: {str: [float] -> float} return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics def aggregation(self):
""" """
return { :returns: {str: [float] -> float}
"acc": mean A dictionary where keys are the names of submetrics and values are
} functions that aggregate a list of metrics
"""
def higher_is_better(self): return {"acc": mean}
"""
:returns: {str: bool} def higher_is_better(self):
A dictionary where keys are the names of submetrics and values are """
whether a higher value of the submetric is better :returns: {str: bool}
""" A dictionary where keys are the names of submetrics and values are
return { whether a higher value of the submetric is better
"acc": True """
} return {"acc": True}
...@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task): ...@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task):
DATASET_PATH = "winograd_wsc" DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273" DATASET_NAME = "wsc273"
upper_pronouns = ["A", "An", "The", "She", "He", upper_pronouns = [
"It", "They", "My", "His", "Her", "Their"] "A",
"An",
"The",
"She",
"He",
"It",
"They",
"My",
"His",
"Her",
"Their",
]
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task): ...@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task):
option += "'s" option += "'s"
# Appropriately lowercase the pronoun in the option. # Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0] pronoun = option.split()[0]
start_of_sentence = doc["text"][doc['pronoun_loc'] - 2] == '.' start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
if not start_of_sentence and pronoun in self.upper_pronouns: if not start_of_sentence and pronoun in self.upper_pronouns:
return option.replace(pronoun, pronoun.lower()) return option.replace(pronoun, pronoun.lower())
return option return option
...@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task): ...@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]]) return self.partial_context(doc, doc["options"][doc["label"]])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["text"]
@classmethod @classmethod
def partial_context(cls, doc, option): def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified # Substitute the pronoun in the original text with the specified
# option and ignore everything after. # option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option return doc["text"][: doc["pronoun_loc"]] + option
def doc_to_target(self, doc): def doc_to_target(self, doc):
return self.partial_target(doc) return self.partial_target(doc)
...@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task): ...@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
return { return {"acc": np.argmax(results) == doc["label"]}
"acc": np.argmax(results) == doc["label"]
}
def aggregation(self): def aggregation(self):
""" """
...@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task): ...@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task): ...@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
...@@ -34,6 +34,7 @@ def simple_parse_args_string(args_string): ...@@ -34,6 +34,7 @@ def simple_parse_args_string(args_string):
args_dict[k] = v args_dict[k] = v
return args_dict return args_dict
def join_iters(iters): def join_iters(iters):
for iter in iters: for iter in iters:
yield from iter yield from iter
...@@ -46,23 +47,26 @@ def chunks(iter, n): ...@@ -46,23 +47,26 @@ def chunks(iter, n):
if len(arr) == n: if len(arr) == n:
yield arr yield arr
arr = [] arr = []
if arr: yield arr if arr:
yield arr
def group(arr, fn): def group(arr, fn):
res = collections.defaultdict(list) res = collections.defaultdict(list)
for ob in arr: for ob in arr:
res[fn(ob)].append(ob) res[fn(ob)].append(ob)
return list(res.values()) return list(res.values())
def general_detokenize(string): def general_detokenize(string):
string = string.replace(" n't", "n't") string = string.replace(" n't", "n't")
string = string.replace(" )", ")") string = string.replace(" )", ")")
string = string.replace("( ", "(") string = string.replace("( ", "(")
string = string.replace("\" ", "\"") string = string.replace('" ', '"')
string = string.replace(" \"", "\"") string = string.replace(' "', '"')
string = re.sub(r" (['.,])", r"\1", string) string = re.sub(r" (['.,])", r"\1", string)
return string return string
...@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len ...@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
# Special handling for first window: predict all tokens # Special handling for first window: predict all tokens
first_seq_len = min(max_seq_len, len(token_list)) first_seq_len = min(max_seq_len, len(token_list))
yield ( yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
[prefix_token] + token_list[:first_seq_len - 1],
token_list[:first_seq_len]
)
predicted += first_seq_len predicted += first_seq_len
while predicted < len(token_list): while predicted < len(token_list):
...@@ -105,61 +106,66 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len ...@@ -105,61 +106,66 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
window_end = predicted + window_pred_len window_end = predicted + window_pred_len
yield ( yield (
token_list[window_end - max_seq_len - 1:window_end - 1], token_list[window_end - max_seq_len - 1 : window_end - 1],
token_list[window_end - window_pred_len:window_end], token_list[window_end - window_pred_len : window_end],
) )
predicted += window_pred_len predicted += window_pred_len
def make_disjoint_window(pair): def make_disjoint_window(pair):
""" Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """ """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a, b = pair a, b = pair
return a[:-(len(b) - 1)], b return a[: -(len(b) - 1)], b
class Reorderer: class Reorderer:
def __init__(self, arr, fn): def __init__(self, arr, fn):
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1])) arr = group(arr, lambda x: fn(x[1]))
arr = [ arr = [([y[0] for y in x], x[0][1]) for x in arr]
([y[0] for y in x], x[0][1]) for x in arr
]
arr.sort(key=lambda x: fn(x[1])) arr.sort(key=lambda x: fn(x[1]))
self.arr = arr self.arr = arr
def get_reordered(self): def get_reordered(self):
return [x[1] for x in self.arr] return [x[1] for x in self.arr]
def get_original(self, newarr): def get_original(self, newarr):
res = [None] * self.size res = [None] * self.size
cov = [False] * self.size cov = [False] * self.size
for (inds, _), v in zip(self.arr, newarr): for (inds, _), v in zip(self.arr, newarr):
for ind in inds: for ind in inds:
res[ind] = v res[ind] = v
cov[ind] = True cov[ind] = True
assert all(cov) assert all(cov)
return res return res
def positional_deprecated(fn): def positional_deprecated(fn):
""" """
A decorator to nudge users into passing only keyword args (`kwargs`) to the A decorator to nudge users into passing only keyword args (`kwargs`) to the
wrapped function, `fn`. wrapped function, `fn`.
""" """
@functools.wraps(fn) @functools.wraps(fn)
def _wrapper(*args, **kwargs): def _wrapper(*args, **kwargs):
if len(args) != 1 if inspect.ismethod(fn) else 0: if len(args) != 1 if inspect.ismethod(fn) else 0:
print(f"WARNING: using {fn.__name__} with positional arguments is " print(
f"WARNING: using {fn.__name__} with positional arguments is "
"deprecated and will be disallowed in a future version of " "deprecated and will be disallowed in a future version of "
"lm-evaluation-harness!") "lm-evaluation-harness!"
)
return fn(*args, **kwargs) return fn(*args, **kwargs)
return _wrapper return _wrapper
@positional_deprecated @positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path: def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
""" """
...@@ -169,12 +175,14 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path: ...@@ -169,12 +175,14 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
cur_path = start_path.resolve() cur_path = start_path.resolve()
max_layers = 3 max_layers = 3
for _ in range(max_layers): for _ in range(max_layers):
if (cur_path / 'tests' / 'test_version_stable.py').exists(): if (cur_path / "tests" / "test_version_stable.py").exists():
return cur_path return cur_path
else: else:
cur_path = cur_path.parent.resolve() cur_path = cur_path.parent.resolve()
raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\ raise FileNotFoundError(
f"of {start_path}") f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
)
@positional_deprecated @positional_deprecated
def run_task_tests(task_list: List[str]): def run_task_tests(task_list: List[str]):
...@@ -182,9 +190,16 @@ def run_task_tests(task_list: List[str]): ...@@ -182,9 +190,16 @@ def run_task_tests(task_list: List[str]):
Find the package root and run the tests for the given tasks Find the package root and run the tests for the given tasks
""" """
package_root = find_test_root(start_path=pathlib.Path(__file__)) package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = ' or '.join(task_list) task_string = " or ".join(task_list)
args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}'] args = [
f"{package_root}/tests/test_version_stable.py",
f"--rootdir={package_root}",
"-k",
f"{task_string}",
]
sys.path.append(str(package_root)) sys.path.append(str(package_root))
pytest_return_val = pytest.main(args) pytest_return_val = pytest.main(args)
if pytest_return_val: if pytest_return_val:
raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}") raise ValueError(
\ No newline at end of file f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
)
import argparse import argparse
import json import json
import logging import logging
import fnmatch
from lm_eval import tasks, evaluator from lm_eval import tasks, evaluator
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True) parser.add_argument("--model", required=True)
parser.add_argument('--model_args', default="") parser.add_argument("--model_args", default="")
parser.add_argument('--tasks', default="all_tasks") parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--provide_description', action="store_true") parser.add_argument("--provide_description", action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0) parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None) parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument('--device', type=str, default=None) parser.add_argument("--device", type=str, default=None)
parser.add_argument('--output_path', default=None) parser.add_argument("--output_path", default=None)
parser.add_argument('--limit', type=int, default=None) parser.add_argument("--limit", type=int, default=None)
parser.add_argument('--no_cache', action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument('--description_dict_path', default=None) parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument('--check_integrity', action="store_true") parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args() return parser.parse_args()
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return list(task_names)
def main(): def main():
args = parse_args() args = parse_args()
assert not args.provide_description # not implemented assert not args.provide_description # not implemented
if args.limit: if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") print(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.tasks == "all_tasks": if args.tasks is None:
task_names = tasks.ALL_TASKS task_names = tasks.ALL_TASKS
else: else:
task_names = args.tasks.split(",") task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
description_dict = {} description_dict = {}
if args.description_dict_path: if args.description_dict_path:
with open(args.description_dict_path, 'r') as f: with open(args.description_dict_path, "r") as f:
description_dict = json.load(f) description_dict = json.load(f)
results = evaluator.simple_evaluate( results = evaluator.simple_evaluate(
...@@ -51,11 +86,11 @@ def main(): ...@@ -51,11 +86,11 @@ def main():
no_cache=args.no_cache, no_cache=args.no_cache,
limit=args.limit, limit=args.limit,
description_dict=description_dict, description_dict=description_dict,
check_integrity=args.check_integrity decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
) )
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
print(dumped) print(dumped)
if args.output_path: if args.output_path:
......
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
janitor.py contains a script to remove benchmark data contamination from training data sets. janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165). It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).
## Algorithm ## Algorithm
1) Collects all contamination text files that are to be removed from training data 1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data 2) Filters training data by finding `N`gram matches between the training data
and any contamination and any contamination
1) `N`grams ignore case and punctation and are split on whitespace. 1) `N`grams ignore case and punctuation and are split on whitespace.
2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around 2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
the match, splitting the training data into chunks the match, splitting the training data into chunks
3) Any chunks less than `minimum_slice_length` are removed 3) Any chunks less than `minimum_slice_length` are removed
4) Training data sets split into more than `too_dirty_cutoff` are considered 4) Training data sets split into more than `too_dirty_cutoff` are considered
completey contaminated and removed completey contaminated and removed
OpenAI used: OpenAI used:
``` ```
ngram_n = 13 ngram_n = 13
...@@ -20,7 +20,7 @@ minimum_slice_length = 200 ...@@ -20,7 +20,7 @@ minimum_slice_length = 200
too_dirty_cutoff = 10 too_dirty_cutoff = 10
``` ```
## Compling ## Compiling
Janitor can be used as a pure python program, but it is much faster if the ngram Janitor can be used as a pure python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run code is run in C++. To compile the C++ code, run
...@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor ...@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor
``` ```
If your your compiler isn't linked to python, you may need to add to the above `-undefined dynamic_lookup` If your your compiler isn't linked to python, you may need to add to the above `-undefined dynamic_lookup`
import glob
import argparse
import os
import subprocess
import shutil
from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
def process_task(
working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm
):
command = f"zstd {bucket_file_path}"
logger.info(command)
subprocess.call(command, shell=True)
compressed_file = bucket_file_path + ".zst"
if output_directory:
shutil.move(compressed_file, output_directory)
os.remove(bucket_file_path)
global_tqdm.update()
def compress_and_move(working_directory, output_directory, process_count):
os.makedirs(output_directory, exist_ok=True)
original_info_file_path = os.path.join(working_directory, "info.json")
assert os.path.exists(original_info_file_path)
tasks = []
bucket_file_paths = glob.glob(
os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
)
for bucket_file_path in bucket_file_paths:
task = (process_task, (working_directory, output_directory, bucket_file_path))
tasks.append(task)
pool = TqdmMultiProcessPool(process_count)
def on_done(_):
return None
def on_error(_):
return None
global_progress = tqdm(
total=len(bucket_file_paths), dynamic_ncols=True, unit="file"
)
_ = pool.map(global_progress, tasks, on_error, on_done)
shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)
if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
logfile_path = "compress_and_package.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
compress_and_move(args.working_directory, args.output_directory, args.process_count)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment