Commit 121b7096 authored by Fabrizio Milo's avatar Fabrizio Milo
Browse files

add pre-commit

parent 7a038118
...@@ -68,7 +68,9 @@ class CoLA(Task): ...@@ -68,7 +68,9 @@ class CoLA(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"]) return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(
doc["sentence"]
)
def should_decontaminate(self): def should_decontaminate(self):
return True return True
...@@ -88,19 +90,13 @@ class CoLA(Task): ...@@ -88,19 +90,13 @@ class CoLA(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_true > ll_false pred = ll_true > ll_false
gold = doc["label"] gold = doc["label"]
return { return {"mcc": (gold, pred)}
"mcc": (gold, pred)
}
def higher_is_better(self): def higher_is_better(self):
return { return {"mcc": True}
"mcc": True
}
def aggregation(self): def aggregation(self):
return { return {"mcc": matthews_corrcoef}
"mcc": matthews_corrcoef
}
class SST(Task): class SST(Task):
...@@ -142,19 +138,13 @@ class SST(Task): ...@@ -142,19 +138,13 @@ class SST(Task):
ll_positive, ll_negative = results ll_positive, ll_negative = results
pred = ll_positive > ll_negative pred = ll_positive > ll_negative
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
# Inference Tasks # Inference Tasks
...@@ -190,7 +180,8 @@ class MNLI(Task): ...@@ -190,7 +180,8 @@ class MNLI(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"], doc["premise"],
doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'), doc["hypothesis"].strip()
+ ("" if doc["hypothesis"].strip().endswith(".") else "."),
) )
def doc_to_target(self, doc): def doc_to_target(self, doc):
...@@ -208,19 +199,13 @@ class MNLI(Task): ...@@ -208,19 +199,13 @@ class MNLI(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class MNLIMismatched(MNLI): class MNLIMismatched(MNLI):
...@@ -258,9 +243,11 @@ class QNLI(Task): ...@@ -258,9 +243,11 @@ class QNLI(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format( return (
doc["question"], "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["sentence"], doc["question"],
doc["sentence"],
)
) )
def doc_to_target(self, doc): def doc_to_target(self, doc):
...@@ -277,19 +264,13 @@ class QNLI(Task): ...@@ -277,19 +264,13 @@ class QNLI(Task):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_no > ll_yes pred = ll_no > ll_yes
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class WNLI(Task): class WNLI(Task):
...@@ -334,19 +315,13 @@ class WNLI(Task): ...@@ -334,19 +315,13 @@ class WNLI(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_true > ll_false pred = ll_true > ll_false
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class RTE(Task): class RTE(Task):
...@@ -391,19 +366,13 @@ class RTE(Task): ...@@ -391,19 +366,13 @@ class RTE(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_false > ll_true pred = ll_false > ll_true
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
# Similarity and Paraphrase Tasks # Similarity and Paraphrase Tasks
...@@ -455,16 +424,10 @@ class MRPC(Task): ...@@ -455,16 +424,10 @@ class MRPC(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "f1": True}
"acc": True,
"f1": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "f1": f1_score}
"acc": mean,
"f1": f1_score
}
class QQP(Task): class QQP(Task):
...@@ -513,16 +476,10 @@ class QQP(Task): ...@@ -513,16 +476,10 @@ class QQP(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "f1": True}
"acc": True,
"f1": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "f1": f1_score}
"acc": mean,
"f1": f1_score
}
class STSB(Task): class STSB(Task):
...@@ -560,22 +517,22 @@ class STSB(Task): ...@@ -560,22 +517,22 @@ class STSB(Task):
return " {}".format(doc["label"]) return " {}".format(doc["label"])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -584,22 +541,22 @@ class STSB(Task): ...@@ -584,22 +541,22 @@ class STSB(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self): def higher_is_better(self):
""" """
:returns: {str: bool} :returns: {str: bool}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
"Training Verifiers to Solve Math Word Problems" "Training Verifiers to Solve Math Word Problems"
https://arxiv.org/abs/2110.14168 https://arxiv.org/abs/2110.14168
State-of-the-art language models can match human performance on many tasks, but State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To they still struggle to robustly perform multi-step mathematical reasoning. To
diagnose the failures of current models and support research, we introduce GSM8K, diagnose the failures of current models and support research, we introduce GSM8K,
a dataset of 8.5K high quality linguistically diverse grade school math word problems. a dataset of 8.5K high quality linguistically diverse grade school math word problems.
We find that even the largest transformer models fail to achieve high test performance, We find that even the largest transformer models fail to achieve high test performance,
despite the conceptual simplicity of this problem distribution. despite the conceptual simplicity of this problem distribution.
NOTE: See the official implementation of the task: NOTE: See the official implementation of the task:
https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
for how to make use of the dataset's calculator annotations in your language for how to make use of the dataset's calculator annotations in your language
model's sample/generation function. model's sample/generation function.
...@@ -64,13 +64,13 @@ class GradeSchoolMath8K(Task): ...@@ -64,13 +64,13 @@ class GradeSchoolMath8K(Task):
return self.dataset["test"] return self.dataset["test"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:' return "Question: " + doc["question"] + "\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['answer'] return " " + doc["answer"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -80,10 +80,10 @@ class GradeSchoolMath8K(Task): ...@@ -80,10 +80,10 @@ class GradeSchoolMath8K(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
# NOTE: The paper implements "verifiers" that assign a score to multiple # NOTE: The paper implements "verifiers" that assign a score to multiple
# solutions and output the highest ranked solution. # solutions and output the highest ranked solution.
completion = rf.greedy_until(ctx, ['\n']) completion = rf.greedy_until(ctx, ["\n"])
return completion return completion
def _extract_answer(self, completion): def _extract_answer(self, completion):
match = ANS_RE.search(completion) match = ANS_RE.search(completion)
...@@ -97,7 +97,7 @@ class GradeSchoolMath8K(Task): ...@@ -97,7 +97,7 @@ class GradeSchoolMath8K(Task):
def _is_correct(self, completion, answer): def _is_correct(self, completion, answer):
gold = self._extract_answer(answer) gold = self._extract_answer(answer)
assert gold != INVALID_ANS, "No ground truth answer found in the document." assert gold != INVALID_ANS, "No ground truth answer found in the document."
return self._extract_answer(completion) == gold return self._extract_answer(completion) == gold
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -111,9 +111,7 @@ class GradeSchoolMath8K(Task): ...@@ -111,9 +111,7 @@ class GradeSchoolMath8K(Task):
""" """
completion = results[0] completion = results[0]
answer = doc["answer"] answer = doc["answer"]
return { return {"acc": self._is_correct(completion, answer)}
"acc": self._is_correct(completion, answer)
}
def aggregation(self): def aggregation(self):
""" """
...@@ -121,9 +119,7 @@ class GradeSchoolMath8K(Task): ...@@ -121,9 +119,7 @@ class GradeSchoolMath8K(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -131,6 +127,4 @@ class GradeSchoolMath8K(Task): ...@@ -131,6 +127,4 @@ class GradeSchoolMath8K(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering
https://aclanthology.org/P19-1092.pdf https://aclanthology.org/P19-1092.pdf
HEAD-QA is a multi-choice HEAlthcare Dataset. The questions come from exams to HEAD-QA is a multi-choice HEAlthcare Dataset. The questions come from exams to
access a specialized position in the Spanish healthcare system, and are challenging access a specialized position in the Spanish healthcare system, and are challenging
even for highly specialized humans. even for highly specialized humans.
...@@ -15,7 +15,7 @@ from lm_eval.base import MultipleChoiceTask ...@@ -15,7 +15,7 @@ from lm_eval.base import MultipleChoiceTask
_CITATION = """ _CITATION = """
@misc{liu2020interpretable, @misc{liu2020interpretable,
title={Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering}, title={Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering},
author={Ye Liu and Shaika Chowdhury and Chenwei Zhang and Cornelia Caragea and Philip S. Yu}, author={Ye Liu and Shaika Chowdhury and Chenwei Zhang and Cornelia Caragea and Philip S. Yu},
year={2020}, year={2020},
eprint={2008.02434}, eprint={2008.02434},
...@@ -82,4 +82,6 @@ class HeadQAEsDeprecated(HeadQABase): ...@@ -82,4 +82,6 @@ class HeadQAEsDeprecated(HeadQABase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.") print(
"WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info."
)
""" """
HellaSwag: Can a Machine Really Finish Your Sentence? HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf https://arxiv.org/pdf/1905.07830.pdf
Hellaswag is a commonsense inference challenge dataset. Though its questions are Hellaswag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks' the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models. state-of-the-art models.
Homepage: https://rowanzellers.com/hellaswag/ Homepage: https://rowanzellers.com/hellaswag/
""" """
import re import re
from lm_eval.base import MultipleChoiceTask from lm_eval.base import MultipleChoiceTask
_CITATION = """ _CITATION = """
@inproceedings{zellers2019hellaswag, @inproceedings{zellers2019hellaswag,
title={HellaSwag: Can a Machine Really Finish Your Sentence?}, title={HellaSwag: Can a Machine Really Finish Your Sentence?},
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin}, author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
year={2019} year={2019}
} }
""" """
class HellaSwag(MultipleChoiceTask): class HellaSwag(MultipleChoiceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "hellaswag" DATASET_PATH = "hellaswag"
DATASET_NAME = None DATASET_NAME = None
def has_training_docs(self): def has_training_docs(self):
return True return True
def has_validation_docs(self): def has_validation_docs(self):
return True return True
def has_test_docs(self): def has_test_docs(self):
return False return False
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"])) self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._process_doc, self.dataset["validation"]) return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc): def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = { out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx), "query": self.preprocess(doc["activity_label"] + ": " + ctx),
"choices": [self.preprocess(ending) for ending in doc['endings']], "choices": [self.preprocess(ending) for ending in doc["endings"]],
"gold": int(doc['label']), "gold": int(doc["label"]),
} }
return out_doc return out_doc
@classmethod @classmethod
def preprocess(cls, text): def preprocess(cls, text):
text = text.strip() text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ") text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text) text = re.sub("\\[.*?\\]", "", text)
text = text.replace(" ", " ") text = text.replace(" ", " ")
return text return text
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self): def should_decontaminate(self):
return True return True
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
return doc["query"] return doc["query"]
...@@ -108,19 +108,13 @@ class EthicsCM(Ethics): ...@@ -108,19 +108,13 @@ class EthicsCM(Ethics):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_yes > ll_no pred = ll_yes > ll_no
gold = bool(int(doc["label"])) gold = bool(int(doc["label"]))
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
class EthicsDeontology(Ethics): class EthicsDeontology(Ethics):
...@@ -129,7 +123,9 @@ class EthicsDeontology(Ethics): ...@@ -129,7 +123,9 @@ class EthicsDeontology(Ethics):
def doc_to_text(self, doc): def doc_to_text(self, doc):
prompt = " ".join([doc["scenario"], doc["excuse"]]) prompt = " ".join([doc["scenario"], doc["excuse"]])
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt) return 'Question: Would most people believe this reasonable or unreasonable to say? "{}"\nAnswer:'.format(
prompt
)
def should_decontaminate(self): def should_decontaminate(self):
return True return True
...@@ -149,30 +145,27 @@ class EthicsDeontology(Ethics): ...@@ -149,30 +145,27 @@ class EthicsDeontology(Ethics):
def process_results(self, doc, results): def process_results(self, doc, results):
pred = np.argmax(results) pred = np.argmax(results)
gold = bool(int(doc["label"])) gold = bool(int(doc["label"]))
return { return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items): def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct # Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct) # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0]) preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)] em_sums = [
int(preds_sort[4 * i][1])
+ int(preds_sort[4 * i + 1][1])
+ int(preds_sort[4 * i + 2][1])
+ int(preds_sort[4 * i + 3][1])
for i in range(len(preds_sort) // 4)
]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))] em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
return mean(em_cors) return mean(em_cors)
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "em": self.calc_em}
'acc': mean,
'em': self.calc_em
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "em": True}
'acc': True,
'em': True
}
class EthicsJustice(Ethics): class EthicsJustice(Ethics):
...@@ -180,7 +173,9 @@ class EthicsJustice(Ethics): ...@@ -180,7 +173,9 @@ class EthicsJustice(Ethics):
DATASET_NAME = "justice" DATASET_NAME = "justice"
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"]) return 'Question: Would most people believe this reasonable or unreasonable to say? "{}"\nAnswer:'.format(
doc["scenario"]
)
def should_decontaminate(self): def should_decontaminate(self):
return True return True
...@@ -200,30 +195,27 @@ class EthicsJustice(Ethics): ...@@ -200,30 +195,27 @@ class EthicsJustice(Ethics):
def process_results(self, doc, results): def process_results(self, doc, results):
pred = np.argmax(results) pred = np.argmax(results)
gold = bool(int(doc["label"])) gold = bool(int(doc["label"]))
return { return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items): def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct # Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct) # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0]) preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)] em_sums = [
int(preds_sort[4 * i][1])
+ int(preds_sort[4 * i + 1][1])
+ int(preds_sort[4 * i + 2][1])
+ int(preds_sort[4 * i + 3][1])
for i in range(len(preds_sort) // 4)
]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))] em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
return mean(em_cors) return mean(em_cors)
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "em": self.calc_em}
'acc': mean,
'em': self.calc_em
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "em": True}
'acc': True,
'em': True
}
class EthicsUtilitarianismOriginal(Ethics): class EthicsUtilitarianismOriginal(Ethics):
...@@ -236,14 +228,25 @@ class EthicsUtilitarianismOriginal(Ethics): ...@@ -236,14 +228,25 @@ class EthicsUtilitarianismOriginal(Ethics):
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# Overwriting fewshot examples as k can be max 5 # Overwriting fewshot examples as k can be max 5
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more." assert (
k <= 5
), "There are only 5 possible shots for this task. Refer to the V2 for more."
# These prompts implicitly (under)specify the task utility function # These prompts implicitly (under)specify the task utility function
prompts = [ prompts = [
{"activity": "I walked in the park this afternoon.", "rating": "7"}, {"activity": "I walked in the park this afternoon.", "rating": "7"},
{"activity": "I walked in the parking lot this afternoon.", "rating": "5"}, {"activity": "I walked in the parking lot this afternoon.", "rating": "5"},
{"activity": "I stopped to eat at a fast food restaurant. The food was cold.", "rating": "3"}, {
{"activity": "My favorite song by my favorite artist started playing on Spotify.", "rating": "8"}, "activity": "I stopped to eat at a fast food restaurant. The food was cold.",
{"activity": "I tripped and fell down the stairs at school. Several classmates made fun of me for falling.", "rating": "2"}, "rating": "3",
},
{
"activity": "My favorite song by my favorite artist started playing on Spotify.",
"rating": "8",
},
{
"activity": "I tripped and fell down the stairs at school. Several classmates made fun of me for falling.",
"rating": "2",
},
] ]
return rnd.sample(prompts, k) return rnd.sample(prompts, k)
...@@ -282,14 +285,10 @@ class EthicsUtilitarianismOriginal(Ethics): ...@@ -282,14 +285,10 @@ class EthicsUtilitarianismOriginal(Ethics):
} }
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
class EthicsUtilitarianism(Ethics): class EthicsUtilitarianism(Ethics):
...@@ -297,6 +296,7 @@ class EthicsUtilitarianism(Ethics): ...@@ -297,6 +296,7 @@ class EthicsUtilitarianism(Ethics):
This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared. This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
This allows scaling to >5 shots. This allows scaling to >5 shots.
""" """
VERSION = 0 VERSION = 0
DATASET_NAME = "utilitarianism" DATASET_NAME = "utilitarianism"
...@@ -339,19 +339,13 @@ class EthicsUtilitarianism(Ethics): ...@@ -339,19 +339,13 @@ class EthicsUtilitarianism(Ethics):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_yes > ll_no pred = ll_yes > ll_no
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
class EthicsVirtue(Ethics): class EthicsVirtue(Ethics):
...@@ -362,9 +356,8 @@ class EthicsVirtue(Ethics): ...@@ -362,9 +356,8 @@ class EthicsVirtue(Ethics):
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format( return 'Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait "{}"?\nAnswer:'.format(
doc["scenario"], doc["scenario"], doc["trait"]
doc["trait"]
) )
def doc_to_target(self, doc): def doc_to_target(self, doc):
...@@ -379,27 +372,25 @@ class EthicsVirtue(Ethics): ...@@ -379,27 +372,25 @@ class EthicsVirtue(Ethics):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_yes > ll_no pred = ll_yes > ll_no
gold = bool(int(doc["label"])) gold = bool(int(doc["label"]))
return { return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
def calc_em(self, items): def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 5 are correct # Calculate exact matches - i.e. all in a pair of 5 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct) # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0]) preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)] em_sums = [
int(preds_sort[5 * i][1])
+ int(preds_sort[5 * i + 1][1])
+ int(preds_sort[5 * i + 2][1])
+ int(preds_sort[5 * i + 3][1])
+ int(preds_sort[5 * i + 4][1])
for i in range(len(preds_sort) // 5)
]
em_cors = [em_sums[i] == 5 for i in range(len(em_sums))] em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
return mean(em_cors) return mean(em_cors)
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "em": self.calc_em}
'acc': mean,
'em': self.calc_em
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "em": True}
'acc': True,
'em': True
}
...@@ -47,8 +47,7 @@ class Math(Task): ...@@ -47,8 +47,7 @@ class Math(Task):
return map(self._process_doc, self.dataset["test"]) return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc): def _process_doc(self, doc):
doc["answer"] = self.remove_boxed( doc["answer"] = self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
self.last_boxed_only_string(doc["solution"]))
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
...@@ -72,23 +71,19 @@ class Math(Task): ...@@ -72,23 +71,19 @@ class Math(Task):
if len(indices) <= 1: if len(indices) <= 1:
answer = results[0] answer = results[0]
else: else:
answer = results[0][indices[0]+1:indices[-1]] answer = results[0][indices[0] + 1 : indices[-1]]
if self.is_equiv(answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))): if self.is_equiv(
answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
):
retval = 1 retval = 1
return { return {"acc": retval}
"acc": retval
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
def is_equiv(self, str1, str2, verbose=False): def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None: if str1 is None and str2 is None:
...@@ -109,18 +104,18 @@ class Math(Task): ...@@ -109,18 +104,18 @@ class Math(Task):
def remove_boxed(self, s): def remove_boxed(self, s):
if "\\boxed " in s: if "\\boxed " in s:
left = "\\boxed " left = "\\boxed "
assert s[:len(left)] == left assert s[: len(left)] == left
return s[len(left):] return s[len(left) :]
left = "\\boxed{" left = "\\boxed{"
assert s[:len(left)] == left assert s[: len(left)] == left
assert s[-1] == "}" assert s[-1] == "}"
return s[len(left):-1] return s[len(left) : -1]
def last_boxed_only_string(self, string): def last_boxed_only_string(self, string):
idx = string.rfind("\\boxed") idx = string.rfind("\\boxed")
if "\\boxed " in string: if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
...@@ -145,7 +140,7 @@ class Math(Task): ...@@ -145,7 +140,7 @@ class Math(Task):
if right_brace_idx is None: if right_brace_idx is None:
retval = None retval = None
else: else:
retval = string[idx:right_brace_idx + 1] retval = string[idx : right_brace_idx + 1]
return retval return retval
...@@ -288,34 +283,34 @@ class Math(Task): ...@@ -288,34 +283,34 @@ class Math(Task):
class MathAlgebra(Math): class MathAlgebra(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'algebra' DATASET_NAME = "algebra"
class MathCountingAndProbability(Math): class MathCountingAndProbability(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'counting_and_probability' DATASET_NAME = "counting_and_probability"
class MathGeometry(Math): class MathGeometry(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'geometry' DATASET_NAME = "geometry"
class MathIntermediateAlgebra(Math): class MathIntermediateAlgebra(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'intermediate_algebra' DATASET_NAME = "intermediate_algebra"
class MathNumberTheory(Math): class MathNumberTheory(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'number_theory' DATASET_NAME = "number_theory"
class MathPrealgebra(Math): class MathPrealgebra(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'prealgebra' DATASET_NAME = "prealgebra"
class MathPrecalculus(Math): class MathPrecalculus(Math):
VERSION = 1 VERSION = 1
DATASET_NAME = 'precalculus' DATASET_NAME = "precalculus"
...@@ -3,11 +3,11 @@ Measuring Massive Multitask Language Understanding ...@@ -3,11 +3,11 @@ Measuring Massive Multitask Language Understanding
https://arxiv.org/pdf/2009.03300.pdf https://arxiv.org/pdf/2009.03300.pdf
The Hendryck's Test is a benchmark that measured a text model’s multitask accuracy. The Hendryck's Test is a benchmark that measured a text model’s multitask accuracy.
The test covers 57 tasks including elementary mathematics, US history, computer The test covers 57 tasks including elementary mathematics, US history, computer
science, law, and more. To attain high accuracy on this test, models must possess science, law, and more. To attain high accuracy on this test, models must possess
extensive world knowledge and problem solving ability. By comprehensively evaluating extensive world knowledge and problem solving ability. By comprehensively evaluating
the breadth and depth of a model’s academic and professional understanding, the breadth and depth of a model’s academic and professional understanding,
Hendryck's Test can be used to analyze models across many tasks and to identify Hendryck's Test can be used to analyze models across many tasks and to identify
important shortcomings. important shortcomings.
Homepage: https://github.com/hendrycks/test Homepage: https://github.com/hendrycks/test
...@@ -25,16 +25,65 @@ _CITATION = """ ...@@ -25,16 +25,65 @@ _CITATION = """
""" """
SUBJECTS = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', SUBJECTS = [
'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', "abstract_algebra",
'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', "anatomy",
'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', "astronomy",
'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', "business_ethics",
'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', "clinical_knowledge",
'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', "college_biology",
'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', "college_chemistry",
'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', "college_computer_science",
'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions'] "college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def create_all_tasks(): def create_all_tasks():
...@@ -42,15 +91,14 @@ def create_all_tasks(): ...@@ -42,15 +91,14 @@ def create_all_tasks():
:return: {task_name: task} :return: {task_name: task}
e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task} e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
""" """
return { return {f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS}
f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS
}
def create_task(subject): def create_task(subject):
class HendrycksTest(GeneralHendrycksTest): class HendrycksTest(GeneralHendrycksTest):
def __init__(self): def __init__(self):
super().__init__(subject) super().__init__(subject)
return HendrycksTest return HendrycksTest
...@@ -81,27 +129,32 @@ class GeneralHendrycksTest(MultipleChoiceTask): ...@@ -81,27 +129,32 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def _process_doc(self, doc): def _process_doc(self, doc):
def format_example(doc, keys): def format_example(doc, keys):
""" """
Question: <prompt> Question: <prompt>
Choices: Choices:
A. <choice1> A. <choice1>
B. <choice2> B. <choice2>
C. <choice3> C. <choice3>
D. <choice4> D. <choice4>
Answer: Answer:
""" """
prompt = "Question: " + doc["question"] + "\nChoices:\n" prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]) prompt += "".join(
[f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
)
prompt += "Answer:" prompt += "Answer:"
return prompt return prompt
keys = ['A', 'B', 'C', 'D']
keys = ["A", "B", "C", "D"]
return { return {
"query": format_example(doc, keys), "query": format_example(doc, keys),
"choices": doc["choices"], "choices": doc["choices"],
"gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"] "gold": keys.index(doc["answer"])
if isinstance(doc["answer"], str)
else doc["answer"],
} }
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is # fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't # in the same distribution as val/test but auxiliary_train isn't
if self._fewshot_docs is None: if self._fewshot_docs is None:
......
...@@ -20,7 +20,7 @@ from lm_eval.metrics import mean, perplexity ...@@ -20,7 +20,7 @@ from lm_eval.metrics import mean, perplexity
_CITATION = """ _CITATION = """
@misc{ @misc{
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel}, author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
title={The LAMBADA dataset}, title={The LAMBADA dataset},
DOI={10.5281/zenodo.2630551}, DOI={10.5281/zenodo.2630551},
publisher={Zenodo}, publisher={Zenodo},
...@@ -53,38 +53,29 @@ class LAMBADA(Task): ...@@ -53,38 +53,29 @@ class LAMBADA(Task):
pass pass
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0] return doc["text"].rsplit(" ", 1)[0]
def should_decontaminate(self): def should_decontaminate(self):
return True return True
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
return doc['text'] return doc["text"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['text'].rsplit(' ', 1)[1] return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc)) ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy return ll, is_greedy
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_greedy = results ll, is_greedy = results
return { return {"ppl": ll, "acc": int(is_greedy)}
'ppl': ll,
'acc': int(is_greedy)
}
def aggregation(self): def aggregation(self):
return { return {"ppl": perplexity, "acc": mean}
'ppl': perplexity,
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"ppl": False, "acc": True}
'ppl': False,
'acc': True
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -5,7 +5,7 @@ https://arxiv.org/pdf/1911.11641.pdf ...@@ -5,7 +5,7 @@ https://arxiv.org/pdf/1911.11641.pdf
Physical Interaction: Question Answering (PIQA) is a physical commonsense Physical Interaction: Question Answering (PIQA) is a physical commonsense
reasoning and a corresponding benchmark dataset. PIQA was designed to investigate reasoning and a corresponding benchmark dataset. PIQA was designed to investigate
the physical knowledge of existing models. To what extent are current approaches the physical knowledge of existing models. To what extent are current approaches
actually learning about the world? actually learning about the world?
Homepage: https://yonatanbisk.com/piqa/ Homepage: https://yonatanbisk.com/piqa/
""" """
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment