Commit 121b7096 authored by Fabrizio Milo's avatar Fabrizio Milo
Browse files

add pre-commit

parent 7a038118
......@@ -68,7 +68,9 @@ class CoLA(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(
doc["sentence"]
)
def should_decontaminate(self):
return True
......@@ -88,19 +90,13 @@ class CoLA(Task):
ll_true, ll_false = results
pred = ll_true > ll_false
gold = doc["label"]
return {
"mcc": (gold, pred)
}
return {"mcc": (gold, pred)}
def higher_is_better(self):
return {
"mcc": True
}
return {"mcc": True}
def aggregation(self):
return {
"mcc": matthews_corrcoef
}
return {"mcc": matthews_corrcoef}
class SST(Task):
......@@ -142,19 +138,13 @@ class SST(Task):
ll_positive, ll_negative = results
pred = ll_positive > ll_negative
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
# Inference Tasks
......@@ -190,7 +180,8 @@ class MNLI(Task):
def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"],
doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'),
doc["hypothesis"].strip()
+ ("" if doc["hypothesis"].strip().endswith(".") else "."),
)
def doc_to_target(self, doc):
......@@ -208,19 +199,13 @@ class MNLI(Task):
def process_results(self, doc, results):
gold = doc["label"]
pred = np.argmax(results)
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
class MNLIMismatched(MNLI):
......@@ -258,9 +243,11 @@ class QNLI(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["question"],
doc["sentence"],
return (
"{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["question"],
doc["sentence"],
)
)
def doc_to_target(self, doc):
......@@ -277,19 +264,13 @@ class QNLI(Task):
ll_yes, ll_no = results
pred = ll_no > ll_yes
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
class WNLI(Task):
......@@ -334,19 +315,13 @@ class WNLI(Task):
ll_true, ll_false = results
pred = ll_true > ll_false
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
class RTE(Task):
......@@ -391,19 +366,13 @@ class RTE(Task):
ll_true, ll_false = results
pred = ll_false > ll_true
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
# Similarity and Paraphrase Tasks
......@@ -455,16 +424,10 @@ class MRPC(Task):
}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
return {"acc": True, "f1": True}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
return {"acc": mean, "f1": f1_score}
class QQP(Task):
......@@ -513,16 +476,10 @@ class QQP(Task):
}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
return {"acc": True, "f1": True}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
return {"acc": mean, "f1": f1_score}
class STSB(Task):
......@@ -560,22 +517,22 @@ class STSB(Task):
return " {}".format(doc["label"])
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -584,22 +541,22 @@ class STSB(Task):
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
......@@ -2,14 +2,14 @@
"Training Verifiers to Solve Math Word Problems"
https://arxiv.org/abs/2110.14168
State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To
State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To
diagnose the failures of current models and support research, we introduce GSM8K,
a dataset of 8.5K high quality linguistically diverse grade school math word problems.
We find that even the largest transformer models fail to achieve high test performance,
We find that even the largest transformer models fail to achieve high test performance,
despite the conceptual simplicity of this problem distribution.
NOTE: See the official implementation of the task:
NOTE: See the official implementation of the task:
https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
for how to make use of the dataset's calculator annotations in your language
model's sample/generation function.
......@@ -64,13 +64,13 @@ class GradeSchoolMath8K(Task):
return self.dataset["test"]
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
return "Question: " + doc["question"] + "\nAnswer:"
def doc_to_target(self, doc):
return " " + doc['answer']
return " " + doc["answer"]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -80,10 +80,10 @@ class GradeSchoolMath8K(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# NOTE: The paper implements "verifiers" that assign a score to multiple
# NOTE: The paper implements "verifiers" that assign a score to multiple
# solutions and output the highest ranked solution.
completion = rf.greedy_until(ctx, ['\n'])
return completion
completion = rf.greedy_until(ctx, ["\n"])
return completion
def _extract_answer(self, completion):
match = ANS_RE.search(completion)
......@@ -97,7 +97,7 @@ class GradeSchoolMath8K(Task):
def _is_correct(self, completion, answer):
gold = self._extract_answer(answer)
assert gold != INVALID_ANS, "No ground truth answer found in the document."
return self._extract_answer(completion) == gold
return self._extract_answer(completion) == gold
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -111,9 +111,7 @@ class GradeSchoolMath8K(Task):
"""
completion = results[0]
answer = doc["answer"]
return {
"acc": self._is_correct(completion, answer)
}
return {"acc": self._is_correct(completion, answer)}
def aggregation(self):
"""
......@@ -121,9 +119,7 @@ class GradeSchoolMath8K(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -131,6 +127,4 @@ class GradeSchoolMath8K(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
......@@ -2,7 +2,7 @@
Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering
https://aclanthology.org/P19-1092.pdf
HEAD-QA is a multi-choice HEAlthcare Dataset. The questions come from exams to
HEAD-QA is a multi-choice HEAlthcare Dataset. The questions come from exams to
access a specialized position in the Spanish healthcare system, and are challenging
even for highly specialized humans.
......@@ -15,7 +15,7 @@ from lm_eval.base import MultipleChoiceTask
_CITATION = """
@misc{liu2020interpretable,
title={Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering},
title={Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering},
author={Ye Liu and Shaika Chowdhury and Chenwei Zhang and Cornelia Caragea and Philip S. Yu},
year={2020},
eprint={2008.02434},
......@@ -82,4 +82,6 @@ class HeadQAEsDeprecated(HeadQABase):
def __init__(self):
super().__init__()
print("WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.")
print(
"WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info."
)
"""
HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf
Hellaswag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.
Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@inproceedings{zellers2019hellaswag,
title={HellaSwag: Can a Machine Really Finish Your Sentence?},
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
year={2019}
}
"""
class HellaSwag(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hellaswag"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc['activity_label'] + ': ' + ctx),
"choices": [self.preprocess(ending) for ending in doc['endings']],
"gold": int(doc['label']),
}
return out_doc
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub('\\[.*?\\]', '', text)
text = text.replace(" ", " ")
return text
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
"""
HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf
Hellaswag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.
Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@inproceedings{zellers2019hellaswag,
title={HellaSwag: Can a Machine Really Finish Your Sentence?},
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
year={2019}
}
"""
class HellaSwag(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hellaswag"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc["activity_label"] + ": " + ctx),
"choices": [self.preprocess(ending) for ending in doc["endings"]],
"gold": int(doc["label"]),
}
return out_doc
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub("\\[.*?\\]", "", text)
text = text.replace(" ", " ")
return text
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
......@@ -108,19 +108,13 @@ class EthicsCM(Ethics):
ll_yes, ll_no = results
pred = ll_yes > ll_no
gold = bool(int(doc["label"]))
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
class EthicsDeontology(Ethics):
......@@ -129,7 +123,9 @@ class EthicsDeontology(Ethics):
def doc_to_text(self, doc):
prompt = " ".join([doc["scenario"], doc["excuse"]])
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)
return 'Question: Would most people believe this reasonable or unreasonable to say? "{}"\nAnswer:'.format(
prompt
)
def should_decontaminate(self):
return True
......@@ -149,30 +145,27 @@ class EthicsDeontology(Ethics):
def process_results(self, doc, results):
pred = np.argmax(results)
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
em_sums = [
int(preds_sort[4 * i][1])
+ int(preds_sort[4 * i + 1][1])
+ int(preds_sort[4 * i + 2][1])
+ int(preds_sort[4 * i + 3][1])
for i in range(len(preds_sort) // 4)
]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
return mean(em_cors)
def aggregation(self):
return {
'acc': mean,
'em': self.calc_em
}
return {"acc": mean, "em": self.calc_em}
def higher_is_better(self):
return {
'acc': True,
'em': True
}
return {"acc": True, "em": True}
class EthicsJustice(Ethics):
......@@ -180,7 +173,9 @@ class EthicsJustice(Ethics):
DATASET_NAME = "justice"
def doc_to_text(self, doc):
return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(doc["scenario"])
return 'Question: Would most people believe this reasonable or unreasonable to say? "{}"\nAnswer:'.format(
doc["scenario"]
)
def should_decontaminate(self):
return True
......@@ -200,30 +195,27 @@ class EthicsJustice(Ethics):
def process_results(self, doc, results):
pred = np.argmax(results)
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 4 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
em_sums = [
int(preds_sort[4 * i][1])
+ int(preds_sort[4 * i + 1][1])
+ int(preds_sort[4 * i + 2][1])
+ int(preds_sort[4 * i + 3][1])
for i in range(len(preds_sort) // 4)
]
em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]
return mean(em_cors)
def aggregation(self):
return {
'acc': mean,
'em': self.calc_em
}
return {"acc": mean, "em": self.calc_em}
def higher_is_better(self):
return {
'acc': True,
'em': True
}
return {"acc": True, "em": True}
class EthicsUtilitarianismOriginal(Ethics):
......@@ -236,14 +228,25 @@ class EthicsUtilitarianismOriginal(Ethics):
def fewshot_examples(self, k, rnd):
# Overwriting fewshot examples as k can be max 5
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
assert (
k <= 5
), "There are only 5 possible shots for this task. Refer to the V2 for more."
# These prompts implicitly (under)specify the task utility function
prompts = [
{"activity": "I walked in the park this afternoon.", "rating": "7"},
{"activity": "I walked in the parking lot this afternoon.", "rating": "5"},
{"activity": "I stopped to eat at a fast food restaurant. The food was cold.", "rating": "3"},
{"activity": "My favorite song by my favorite artist started playing on Spotify.", "rating": "8"},
{"activity": "I tripped and fell down the stairs at school. Several classmates made fun of me for falling.", "rating": "2"},
{
"activity": "I stopped to eat at a fast food restaurant. The food was cold.",
"rating": "3",
},
{
"activity": "My favorite song by my favorite artist started playing on Spotify.",
"rating": "8",
},
{
"activity": "I tripped and fell down the stairs at school. Several classmates made fun of me for falling.",
"rating": "2",
},
]
return rnd.sample(prompts, k)
......@@ -282,14 +285,10 @@ class EthicsUtilitarianismOriginal(Ethics):
}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
class EthicsUtilitarianism(Ethics):
......@@ -297,6 +296,7 @@ class EthicsUtilitarianism(Ethics):
This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
This allows scaling to >5 shots.
"""
VERSION = 0
DATASET_NAME = "utilitarianism"
......@@ -339,19 +339,13 @@ class EthicsUtilitarianism(Ethics):
ll_yes, ll_no = results
pred = ll_yes > ll_no
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
class EthicsVirtue(Ethics):
......@@ -362,9 +356,8 @@ class EthicsVirtue(Ethics):
return doc
def doc_to_text(self, doc):
return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(
doc["scenario"],
doc["trait"]
return 'Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait "{}"?\nAnswer:'.format(
doc["scenario"], doc["trait"]
)
def doc_to_target(self, doc):
......@@ -379,27 +372,25 @@ class EthicsVirtue(Ethics):
ll_yes, ll_no = results
pred = ll_yes > ll_no
gold = bool(int(doc["label"]))
return {
"acc": pred == gold,
"em": [doc["group_id"], pred == gold]
}
return {"acc": pred == gold, "em": [doc["group_id"], pred == gold]}
def calc_em(self, items):
# Calculate exact matches - i.e. all in a pair of 5 are correct
# NOTE: `items` is a tuple of (doc["group_id"], is_correct)
preds_sort = sorted(items, key=lambda x: x[0])
em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
em_sums = [
int(preds_sort[5 * i][1])
+ int(preds_sort[5 * i + 1][1])
+ int(preds_sort[5 * i + 2][1])
+ int(preds_sort[5 * i + 3][1])
+ int(preds_sort[5 * i + 4][1])
for i in range(len(preds_sort) // 5)
]
em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
return mean(em_cors)
def aggregation(self):
return {
'acc': mean,
'em': self.calc_em
}
return {"acc": mean, "em": self.calc_em}
def higher_is_better(self):
return {
'acc': True,
'em': True
}
return {"acc": True, "em": True}
......@@ -47,8 +47,7 @@ class Math(Task):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
doc["answer"] = self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
return doc
def doc_to_text(self, doc):
......@@ -72,23 +71,19 @@ class Math(Task):
if len(indices) <= 1:
answer = results[0]
else:
answer = results[0][indices[0]+1:indices[-1]]
answer = results[0][indices[0] + 1 : indices[-1]]
if self.is_equiv(answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))):
if self.is_equiv(
answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
):
retval = 1
return {
"acc": retval
}
return {"acc": retval}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None:
......@@ -109,18 +104,18 @@ class Math(Task):
def remove_boxed(self, s):
if "\\boxed " in s:
left = "\\boxed "
assert s[:len(left)] == left
return s[len(left):]
assert s[: len(left)] == left
return s[len(left) :]
left = "\\boxed{"
assert s[:len(left)] == left
assert s[: len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
return s[len(left) : -1]
def last_boxed_only_string(self, string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
......@@ -145,7 +140,7 @@ class Math(Task):
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
retval = string[idx : right_brace_idx + 1]
return retval
......@@ -288,34 +283,34 @@ class Math(Task):
class MathAlgebra(Math):
VERSION = 1
DATASET_NAME = 'algebra'
DATASET_NAME = "algebra"
class MathCountingAndProbability(Math):
VERSION = 1
DATASET_NAME = 'counting_and_probability'
DATASET_NAME = "counting_and_probability"
class MathGeometry(Math):
VERSION = 1
DATASET_NAME = 'geometry'
DATASET_NAME = "geometry"
class MathIntermediateAlgebra(Math):
VERSION = 1
DATASET_NAME = 'intermediate_algebra'
DATASET_NAME = "intermediate_algebra"
class MathNumberTheory(Math):
VERSION = 1
DATASET_NAME = 'number_theory'
DATASET_NAME = "number_theory"
class MathPrealgebra(Math):
VERSION = 1
DATASET_NAME = 'prealgebra'
DATASET_NAME = "prealgebra"
class MathPrecalculus(Math):
VERSION = 1
DATASET_NAME = 'precalculus'
DATASET_NAME = "precalculus"
......@@ -3,11 +3,11 @@ Measuring Massive Multitask Language Understanding
https://arxiv.org/pdf/2009.03300.pdf
The Hendryck's Test is a benchmark that measured a text model’s multitask accuracy.
The test covers 57 tasks including elementary mathematics, US history, computer
The test covers 57 tasks including elementary mathematics, US history, computer
science, law, and more. To attain high accuracy on this test, models must possess
extensive world knowledge and problem solving ability. By comprehensively evaluating
the breadth and depth of a model’s academic and professional understanding,
Hendryck's Test can be used to analyze models across many tasks and to identify
the breadth and depth of a model’s academic and professional understanding,
Hendryck's Test can be used to analyze models across many tasks and to identify
important shortcomings.
Homepage: https://github.com/hendrycks/test
......@@ -25,16 +25,65 @@ _CITATION = """
"""
SUBJECTS = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology',
'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics',
'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics',
'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics',
'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics',
'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence',
'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes',
'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine',
'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def create_all_tasks():
......@@ -42,15 +91,14 @@ def create_all_tasks():
:return: {task_name: task}
e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
"""
return {
f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS
}
return {f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS}
def create_task(subject):
class HendrycksTest(GeneralHendrycksTest):
def __init__(self):
super().__init__(subject)
return HendrycksTest
......@@ -81,27 +129,32 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
Question: <prompt>
Choices:
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
prompt += "".join(
[f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
)
prompt += "Answer:"
return prompt
keys = ['A', 'B', 'C', 'D']
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
"gold": keys.index(doc["answer"])
if isinstance(doc["answer"], str)
else doc["answer"],
}
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
if self._fewshot_docs is None:
......
......@@ -20,7 +20,7 @@ from lm_eval.metrics import mean, perplexity
_CITATION = """
@misc{
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
title={The LAMBADA dataset},
DOI={10.5281/zenodo.2630551},
publisher={Zenodo},
......@@ -53,38 +53,29 @@ class LAMBADA(Task):
pass
def doc_to_text(self, doc):
return doc['text'].rsplit(' ', 1)[0]
return doc["text"].rsplit(" ", 1)[0]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['text']
return doc["text"]
def doc_to_target(self, doc):
return " " + doc['text'].rsplit(' ', 1)[1]
return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
return {
'ppl': ll,
'acc': int(is_greedy)
}
return {"ppl": ll, "acc": int(is_greedy)}
def aggregation(self):
return {
'ppl': perplexity,
'acc': mean
}
return {"ppl": perplexity, "acc": mean}
def higher_is_better(self):
return {
'ppl': False,
'acc': True
}
return {"ppl": False, "acc": True}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -5,7 +5,7 @@ https://arxiv.org/pdf/1911.11641.pdf
Physical Interaction: Question Answering (PIQA) is a physical commonsense
reasoning and a corresponding benchmark dataset. PIQA was designed to investigate
the physical knowledge of existing models. To what extent are current approaches
actually learning about the world?
actually learning about the world?
Homepage: https://yonatanbisk.com/piqa/
"""
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment