Commit 63001b68 authored by bzantium's avatar bzantium
Browse files

fix conflict for upstream

parents a956bc63 84ef60ee
......@@ -29,7 +29,7 @@ class MuTualBase(Task):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.mutual.mutual)
DATASET_NAME = None
CHOICES = ['A', 'B', 'C', 'D']
CHOICES = ["A", "B", "C", "D"]
def has_training_docs(self):
return True
......@@ -52,6 +52,12 @@ class MuTualBase(Task):
def doc_to_text(self, doc):
return self.detokenize(doc["article"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["article"]
def doc_to_target(self, doc):
return " " + self.detokenize(doc["options"][self.CHOICES.index(doc["answers"])])
......@@ -82,26 +88,14 @@ class MuTualBase(Task):
r4_1 = np.argmax(results) == gold # r4_1 = accuracy
ranks = sorted(results, reverse=True)
r4_2 = (ranks.index(results[gold]) == 1) + r4_1
mrr = 1. / (ranks.index(results[gold]) + 1) # `+ 1` for index offset
return {
"r@1": r4_1,
"r@2": r4_2,
"mrr": mrr
}
mrr = 1.0 / (ranks.index(results[gold]) + 1) # `+ 1` for index offset
return {"r@1": r4_1, "r@2": r4_2, "mrr": mrr}
def aggregation(self):
return {
"r@1": mean,
"r@2": mean,
"mrr": mean
}
return {"r@1": mean, "r@2": mean, "mrr": mean}
def higher_is_better(self):
return {
"r@1": True,
"r@2": True,
"mrr": True
}
return {"r@1": True, "r@2": True, "mrr": True}
class MuTual(MuTualBase):
......
......@@ -61,21 +61,35 @@ class NaturalQs(Task):
return rnd.sample(self._training_docs, k)
def doc_to_text(self, doc):
return 'Q: ' + doc['question']['text'] + '\n\n' + 'A:'
return "Q: " + doc["question"]["text"] + "\n\n" + "A:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]["text"]
def doc_to_target(self, doc):
# There's a short answer and a long answer. Based on the paper, I'm using the long answer.
short_answer = doc['annotations']['short_answers'][0]['text']
long_answer_start = doc['annotations']['long_answer'][0]['start_token']
long_answer_end = doc['annotations']['long_answer'][0]['end_token']
long_answer_span = doc['document']['tokens']['token'][long_answer_start:long_answer_end]
long_answer_is_html = doc['document']['tokens']['is_html'][long_answer_start:long_answer_end]
long_answer_chars = [tok for (tok, is_html) in zip(long_answer_span, long_answer_is_html) if not is_html]
# short_answer = doc["annotations"]["short_answers"][0]["text"]
long_answer_start = doc["annotations"]["long_answer"][0]["start_token"]
long_answer_end = doc["annotations"]["long_answer"][0]["end_token"]
long_answer_span = doc["document"]["tokens"]["token"][
long_answer_start:long_answer_end
]
long_answer_is_html = doc["document"]["tokens"]["is_html"][
long_answer_start:long_answer_end
]
long_answer_chars = [
tok
for (tok, is_html) in zip(long_answer_span, long_answer_is_html)
if not is_html
]
long_answer = " ".join(long_answer_chars)
return long_answer # Replace with short_answer[0] for short answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -86,7 +100,7 @@ class NaturalQs(Task):
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -99,7 +113,7 @@ class NaturalQs(Task):
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def aggregation(self):
"""
......@@ -108,7 +122,7 @@ class NaturalQs(Task):
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self):
"""
......@@ -117,4 +131,4 @@ class NaturalQs(Task):
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
......@@ -63,3 +63,9 @@ class OpenBookQA(MultipleChoiceTask):
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
"""
PAWS-X: A Cross-lingual Adversarial Dataset for Paraphrase Identification
https://arxiv.org/abs/1908.11828
The dataset consists of 23,659 human translated PAWS evaluation pairs and
296,406 machine translated training pairs in 6 typologically distinct languages.
Examples are adapted from PAWS-Wiki
Prompt format (same as in mGPT):
"<s>" + sentence1 + ", right? " + mask + ", " + sentence2 + "</s>",
where mask is the string that matches the label:
Yes, No.
Example:
<s> The Tabaci River is a tributary of the River Leurda in Romania, right? No, The Leurda River is a tributary of the River Tabaci in Romania.</s>
Language specific prompts are translated word-by-word with Google Translate
and may differ from the ones used by mGPT and XGLM (they do not provide their prompts).
Homepage: https://github.com/google-research-datasets/paws/tree/master/pawsx
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
_CITATION = """
@inproceedings{yang-etal-2019-paws,
title = "{PAWS}-{X}: A Cross-lingual Adversarial Dataset for Paraphrase Identification",
author = "Yang, Yinfei and
Zhang, Yuan and
Tar, Chris and
Baldridge, Jason",
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
month = nov,
year = "2019",
address = "Hong Kong, China",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D19-1382",
doi = "10.18653/v1/D19-1382",
pages = "3687--3692",
}"""
class PAWSXBase(Task):
VERSION = 0
DATASET_PATH = "paws-x"
DATASET_NAME = None # 'en'
YES = None # 'Yes'
NO = None # 'No'
QUESTION_WORD = None # 'right'
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
# same as in mGPT paper
return (
doc["sentence1"]
+ ", "
+ self.QUESTION_WORD
+ "? [MASK], "
+ doc["sentence2"]
)
def doc_to_target(self, doc):
return " " + [self.YES, self.NO][doc["label"]]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or
test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_yes = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.YES))
ll_no = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.NO))
return ll_yes, ll_no
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
ll_yes, ll_no = results
pred = ll_yes > ll_no
true_label = doc["label"]
return {
"acc": pred == true_label,
}
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
return {
"acc": mean,
}
def higher_is_better(self):
return {"acc": True}
class PAWSX_en(PAWSXBase):
DATASET_NAME = "en"
YES = "Yes"
NO = "No"
QUESTION_WORD = "right"
class PAWSX_de(PAWSXBase):
DATASET_NAME = "de"
YES = "Ja"
NO = "Nein"
QUESTION_WORD = "richtig"
class PAWSX_fr(PAWSXBase):
DATASET_NAME = "fr"
YES = "Oui"
NO = "No"
QUESTION_WORD = "right"
class PAWSX_es(PAWSXBase):
DATASET_NAME = "es"
YES = "Sí"
NO = "No"
QUESTION_WORD = "verdad"
class PAWSX_ja(PAWSXBase):
DATASET_NAME = "ja"
YES = "はい"
NO = "いいえ"
QUESTION_WORD = "ですね"
class PAWSX_ko(PAWSXBase):
DATASET_NAME = "ko"
YES = "예"
NO = "아니요"
QUESTION_WORD = "맞죠"
class PAWSX_zh(PAWSXBase):
DATASET_NAME = "zh"
YES = "是"
NO = "不是"
QUESTION_WORD = "对吧"
LANGS = [
"en",
"de",
"es",
"fr",
"ja",
"ko",
"zh",
]
LANG_CLASSES = [
PAWSX_en,
PAWSX_de,
PAWSX_es,
PAWSX_fr,
PAWSX_ja,
PAWSX_ko,
PAWSX_zh,
]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"pawsx_{lang}"] = lang_class
return tasks
......@@ -58,3 +58,9 @@ class PiQA(MultipleChoiceTask):
def doc_to_text(self, doc):
return "Question: " + doc["goal"] + "\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["goal"]
......@@ -52,22 +52,29 @@ class PROST(MultipleChoiceTask):
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "PROST is designed to probe models in a zero-shot fashion only."
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
def _process_doc(self, doc):
out_doc = {
"query": f"{doc['context']}\nQuestion: {doc['ex_question']}\nAnswer:",
"choices": [doc['A'], doc['B'], doc['C'], doc['D']],
"gold": doc['label'],
"choices": [doc["A"], doc["B"], doc["C"], doc["D"]],
"gold": doc["label"],
}
return out_doc
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
......@@ -53,16 +53,20 @@ class Pubmed_QA(Task):
def doc_to_text(self, doc):
ctxs = "\n".join(doc["context"]["contexts"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(
ctxs,
doc["question"],
doc["final_decision"]
ctxs, doc["question"], doc["final_decision"]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"] + " " + "\n".join(doc["context"]["contexts"])
def doc_to_target(self, doc):
return " {}".format(doc["final_decision"])
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns
"""Uses RequestFactory to construct Requests and returns
an iterable of Requests which will be sent to the LM.
"""
ll_yes, _ = rf.loglikelihood(ctx, " yes")
......@@ -79,11 +83,7 @@ class Pubmed_QA(Task):
}
def aggregation(self):
return {
"acc" : mean
}
return {"acc": mean}
def higher_is_better(self):
return {
"acc" : True
}
return {"acc": True}
......@@ -23,7 +23,7 @@ _CITATION = """
booktitle={CLEF},
year={2013}
}
"""
""" # noqa: W605
class QA4MRE(MultipleChoiceTask):
......@@ -47,7 +47,7 @@ class QA4MRE(MultipleChoiceTask):
def _process_doc(self, doc):
choices = doc["answer_options"]["answer_str"]
out_doc = {
"source": doc["document_str"].strip().replace("\'", "'"),
"source": doc["document_str"].strip().replace("'", "'"),
"query": doc["question_str"],
"choices": choices,
"gold": int(doc["correct_answer_id"]) - 1,
......@@ -57,6 +57,12 @@ class QA4MRE(MultipleChoiceTask):
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
class QA4MRE_2011(QA4MRE):
DATASET_NAME = "2011.main.EN"
......
......@@ -214,7 +214,7 @@ class QASPER(Task):
"""
# unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
if doc["answer_type"] in ("free form answer"):
return [rf.greedy_until(ctx, ["\n"])]
return [rf.greedy_until(ctx, {"until": ["\n"]})]
elif doc["answer_type"] in ("bool"):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
......
......@@ -51,17 +51,34 @@ class QuAC(Task):
raise NotImplementedError("QuAC has no test docs.")
def _process_doc(self, doc):
doc["title"] = doc['title'] + ' - ' + doc['section_title']
doc["title"] = doc["title"] + " - " + doc["section_title"]
return doc
def doc_to_text(self, doc):
return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '
return (
"TITLE: "
+ doc["title"]
+ "\n"
+ "PARAGRAPH: "
+ doc["paragraph"]
+ "\n\n"
+ "Q: "
+ doc["question"]
+ "\n\n"
+ "A: "
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["paragraph"]
def doc_to_target(self, doc):
return doc['answer']
return doc["answer"]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -72,7 +89,7 @@ class QuAC(Task):
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -85,7 +102,7 @@ class QuAC(Task):
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def aggregation(self):
"""
......@@ -94,7 +111,7 @@ class QuAC(Task):
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self):
"""
......@@ -103,4 +120,4 @@ class QuAC(Task):
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
raise NotImplementedError("Evaluation not implemented")
......@@ -40,7 +40,7 @@ class RACE(Task):
DATASET_NAME = "high"
cache = {}
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
def has_training_docs(self):
return True
......@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage.
r = collections.defaultdict(list)
for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
r[item['article']].append(item)
res = list(r.values() >> each(lambda x: {
'article': x[0]['article'],
'problems': x >> each(lambda y: {
'question': y['question'],
'answer': y['answer'],
'options': y['options'],
})
}))
for item in datasets.load_dataset(
path=self.DATASET_PATH, name=self.DATASET_NAME
)[set]:
r[item["article"]].append(item)
res = list(
r.values()
>> each(
lambda x: {
"article": x[0]["article"],
"problems": x
>> each(
lambda y: {
"question": y["question"],
"answer": y["answer"],
"options": y["options"],
}
),
}
)
)
self.cache[set] = res
return res
......@@ -85,30 +95,38 @@ class RACE(Task):
@classmethod
def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']]
return problem['options'][answer]
answer = cls.letter_to_num[problem["answer"]]
return problem["options"][answer]
@classmethod
def last_problem(cls, doc):
return doc['problems'][-1]
return doc["problems"][-1]
def doc_to_text(self, doc):
text = 'Article: ' + doc['article'] + '\n\n'
for problem in doc['problems'][:-1]:
if problem['question'][-6:] == ' _ .':
text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
text = "Article: " + doc["article"] + "\n\n"
for problem in doc["problems"][:-1]:
if problem["question"][-6:] == " _ .":
text += (
problem["question"][-5:] + self.get_answer_option(problem) + "\n"
)
else:
question = 'Question: ' + problem['question'] + '\n'
answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
question = "Question: " + problem["question"] + "\n"
answer = "Answer: " + self.get_answer_option(problem) + "\n"
text += question + answer
text += self.last_problem(doc)['question']
text += self.last_problem(doc)["question"]
return text
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["article"]
def doc_to_target(self, doc):
return " " + self.get_answer_option(self.last_problem(doc))
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -120,8 +138,7 @@ class RACE(Task):
"""
problem = self.last_problem(doc)
ll_choices = [
rf.loglikelihood(ctx, " " + problem['options'][i])[0]
for i in range(4)
rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
]
return ll_choices
......@@ -135,11 +152,9 @@ class RACE(Task):
:param results:
The results of the requests created in construct_requests.
"""
gold = self.letter_to_num[self.last_problem(doc)['answer']]
gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred = np.argmax(results)
return {
"acc": int(pred == gold)
}
return {"acc": int(pred == gold)}
def aggregation(self):
"""
......@@ -147,9 +162,7 @@ class RACE(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -157,6 +170,4 @@ class RACE(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
......@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask):
def _process_doc(self, doc):
return {
'source': doc['source'],
'query': doc['stem'].split(' ')[:2],
'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]],
'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()),
"source": doc["source"],
"query": doc["stem"].split(" ")[:2],
"choices": [
"{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]
],
"gold": ["a", "b", "c", "d", "e"].index(doc["solution"].strip()),
}
def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query'])
return "{} is to {} as".format(*doc["query"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + "\n" + " ".join(doc["query"])
......@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask):
doc["distractor3"],
doc["correct_answer"],
]
src = doc['support']
src = doc["support"]
out_doc = {
"source": src,
"query": doc['question'],
"query": doc["question"],
"choices": choices,
"gold": 3,
}
......@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask):
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
......@@ -40,7 +40,7 @@ def _squad_metric(predictions, references):
def _squad_agg(key, items):
predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key]
return _squad_metric(predictions=predictions, references=references).get(key, 0)
class SQuAD2(Task):
......@@ -49,7 +49,9 @@ class SQuAD2(Task):
DATASET_NAME = None
# HF changed squad on us so we have to make sure we aren't running the old one
assert version.parse(datasets.__version__) >= version.parse("1.11.0"), "datasets v1.11.0 or later required for SQuAD"
assert version.parse(datasets.__version__) >= version.parse(
"1.11.0"
), "datasets v1.11.0 or later required for SQuAD"
def has_training_docs(self):
return True
......@@ -67,18 +69,35 @@ class SQuAD2(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
return (
"Title: "
+ doc["title"]
+ "\n\n"
+ "Background: "
+ doc["context"]
+ "\n\n"
+ "Question: "
+ doc["question"]
+ "\n\n"
+ "Answer:"
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc):
answer_list = doc['answers']['text']
answer_list = doc["answers"]["text"]
if len(answer_list) > 0:
answer = answer_list[0]
else:
answer = 'unanswerable'
answer = "unanswerable"
return " " + answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -88,7 +107,7 @@ class SQuAD2(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, ['\n'])
continuation = rf.greedy_until(ctx, {"until": ["\n"]})
is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
return continuation, is_unanswerable
......@@ -107,25 +126,46 @@ class SQuAD2(Task):
no_answer_probability = exp(logprob_unanswerable)
predictions = {
'id': doc['id'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
"id": doc["id"],
"prediction_text": continuation,
"no_answer_probability": no_answer_probability,
}
references = {
'id': doc['id'],
'answers': doc['answers'],
"id": doc["id"],
"answers": doc["answers"],
}
return {
'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'HasAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'NoAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'best_exact': (predictions, references), # Best exact match (with varying threshold)
'best_f1': (predictions, references), # Best F1 (with varying threshold)
"exact": (
predictions,
references,
), # Exact match (the normalized answer exactly match the gold answer)
"f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": (
predictions,
references,
), # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": (
predictions,
references,
), # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"best_exact": (
predictions,
references,
), # Best exact match (with varying threshold)
"best_f1": (predictions, references), # Best F1 (with varying threshold)
}
def aggregation(self):
......@@ -135,14 +175,30 @@ class SQuAD2(Task):
functions that aggregate a list of metrics
"""
return {
'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer
'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer
'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold)
'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold)
"exact": partial(
_squad_agg, "exact"
), # Exact match (the normalized answer exactly match the gold answer)
"f1": partial(
_squad_agg, "f1"
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": partial(
_squad_agg, "HasAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": partial(
_squad_agg, "HasAns_f1"
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": partial(
_squad_agg, "NoAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": partial(
_squad_agg, "NoAns_f1"
), # The F-score of predicted tokens versus the gold answer
"best_exact": partial(
_squad_agg, "best_exact"
), # Best exact match (with varying threshold)
"best_f1": partial(
_squad_agg, "best_f1"
), # Best F1 (with varying threshold)
}
def higher_is_better(self):
......@@ -152,12 +208,12 @@ class SQuAD2(Task):
whether a higher value of the submetric is better
"""
return {
'exact': True, # Exact match (the normalized answer exactly match the gold answer)
'f1': True, # The F-score of predicted tokens versus the gold answer
'HasAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': True, # The F-score of predicted tokens versus the gold answer
'NoAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': True, # The F-score of predicted tokens versus the gold answer
'best_exact': True, # Best exact match (with varying threshold)
'best_f1': True, # Best F1 (with varying threshold)
"exact": True, # Exact match (the normalized answer exactly match the gold answer)
"f1": True, # The F-score of predicted tokens versus the gold answer
"HasAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": True, # The F-score of predicted tokens versus the gold answer
"NoAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": True, # The F-score of predicted tokens versus the gold answer
"best_exact": True, # Best exact match (with varying threshold)
"best_f1": True, # Best F1 (with varying threshold)
}
......@@ -65,12 +65,27 @@ class StoryCloze(Task):
return self.dataset["test"]
def doc_to_text(self, doc):
return ' '.join([
return " ".join(
[
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
])
]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return " ".join(
[
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
]
)
def doc_to_target(self, doc):
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
......@@ -78,7 +93,7 @@ class StoryCloze(Task):
return " " + clozes[doc["answer_right_ending"] - 1]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -89,10 +104,7 @@ class StoryCloze(Task):
part of the document for `doc`.
"""
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0]
for choice in clozes
]
lls = [rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in clozes]
return lls
def process_results(self, doc, results):
......@@ -106,10 +118,8 @@ class StoryCloze(Task):
The results of the requests created in construct_requests.
"""
gold = doc["answer_right_ending"] - 1
acc = 1. if np.argmax(results) == gold else 0.
return {
"acc": acc
}
acc = 1.0 if np.argmax(results) == gold else 0.0
return {"acc": acc}
def aggregation(self):
"""
......@@ -117,9 +127,7 @@ class StoryCloze(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -127,9 +135,7 @@ class StoryCloze(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
class StoryCloze2016(StoryCloze):
......
......@@ -57,13 +57,19 @@ class BoolQ(Task):
def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["passage"]
def doc_to_target(self, doc):
return " " + yesno(doc['label'])
return " " + yesno(doc["label"])
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes')
ll_no, _ = rf.loglikelihood(ctx, ' no')
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
......@@ -71,21 +77,15 @@ class BoolQ(Task):
ll_yes, ll_no = results
gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0.
acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return {
"acc": acc
}
return {"acc": acc}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
class CommitmentBank(Task):
......@@ -123,27 +123,21 @@ class CommitmentBank(Task):
return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, ' True')
ll_false, _ = rf.loglikelihood(ctx, ' False')
ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_false, _ = rf.loglikelihood(ctx, " False")
ll_neither, _ = rf.loglikelihood(ctx, " Neither")
return ll_true, ll_false, ll_neither
def process_results(self, doc, results):
gold = doc["label"]
pred = np.argmax(results)
acc = 1. if pred == gold else 0.
acc = 1.0 if pred == gold else 0.0
return {
"acc": acc,
"f1": (pred, gold)
}
return {"acc": acc, "f1": (pred, gold)}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
return {"acc": True, "f1": True}
@classmethod
def cb_multi_fi(cls, items):
......@@ -210,21 +204,15 @@ class Copa(Task):
def process_results(self, doc, results):
gold = doc["label"]
pred = np.argmax(results)
acc = 1. if pred == gold else 0.
acc = 1.0 if pred == gold else 0.0
return {
"acc": acc
}
return {"acc": acc}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
@staticmethod
def convert_choice(choice):
......@@ -268,27 +256,21 @@ class MultiRC(Task):
true_choice = self.format_answer(answer=doc["answer"], label=True)
false_choice = self.format_answer(answer=doc["answer"], label=False)
ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}')
ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}')
ll_true_choice, _ = rf.loglikelihood(ctx, f" {true_choice}")
ll_false_choice, _ = rf.loglikelihood(ctx, f" {false_choice}")
return ll_true_choice, ll_false_choice
def process_results(self, doc, results):
ll_true_choice, ll_false_choice = results
pred = ll_true_choice > ll_false_choice
return {
"acc": (pred, doc)
}
return {"acc": (pred, doc)}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": acc_all
}
return {"acc": acc_all}
class ReCoRD(Task):
......@@ -337,7 +319,7 @@ class ReCoRD(Task):
@classmethod
def format_answer(cls, query, entity):
return f' - {query}'.replace("@placeholder", entity)
return f" - {query}".replace("@placeholder", entity)
def doc_to_target(self, doc):
# We only output the first correct entity in a doc
......@@ -359,8 +341,12 @@ class ReCoRD(Task):
prediction = doc["entities"][max_idx]
gold_label_set = doc["answers"]
f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set)
em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set)
f1 = metric_max_over_ground_truths(
squad_metrics.compute_f1, prediction, gold_label_set
)
em = metric_max_over_ground_truths(
squad_metrics.compute_exact, prediction, gold_label_set
)
return {
"f1": f1,
......@@ -403,19 +389,21 @@ class WordsInContext(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
return (
"Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the"
" two sentences above?\nAnswer:".format(
doc["sentence1"],
doc["sentence2"],
doc["sentence1"][doc["start1"]:doc["end1"]],
doc["sentence1"][doc["start1"] : doc["end1"]],
)
)
def doc_to_target(self, doc):
return " {}".format({0: "no", 1: "yes"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes')
ll_no, _ = rf.loglikelihood(ctx, ' no')
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
......@@ -423,21 +411,15 @@ class WordsInContext(Task):
ll_yes, ll_no = results
gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0.
acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return {
"acc": acc
}
return {"acc": acc}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
class SGWinogradSchemaChallenge(Task):
......@@ -461,9 +443,7 @@ class SGWinogradSchemaChallenge(Task):
if self._training_docs is None:
# GPT-3 Paper's format only uses positive examples for fewshot "training"
self._training_docs = [
doc for doc in
self.dataset["train"]
if doc["label"]
doc for doc in self.dataset["train"] if doc["label"]
]
return self._training_docs
......@@ -473,25 +453,25 @@ class SGWinogradSchemaChallenge(Task):
def doc_to_text(self, doc):
raw_passage = doc["text"]
# NOTE: HuggingFace span indices are word-based not character-based.
pre = " ".join(raw_passage.split()[:doc["span2_index"]])
post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post)
pre = " ".join(raw_passage.split()[: doc["span2_index"]])
post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :]
passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post)
noun = doc["span1_text"]
pronoun = doc["span2_text"]
text = (
f"Passage: {passage}\n"
+ f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n"
+ f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
+ "Answer:"
)
return text
def doc_to_target(self, doc):
return " " + yesno(doc['label'])
return " " + yesno(doc["label"])
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, ' yes')
ll_no, _ = rf.loglikelihood(ctx, ' no')
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
......@@ -499,18 +479,12 @@ class SGWinogradSchemaChallenge(Task):
ll_yes, ll_no = results
gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0.
acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
return {
"acc": acc
}
return {"acc": acc}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
"""
SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference
https://arxiv.org/pdf/1808.05326.pdf
SWAG (Situations With Adversarial Generations) is an adversarial dataset
that consists of 113k multiple choice questions about grounded situations. Each
question is a video caption from LSMDC or ActivityNet Captions, with four answer
choices about what might happen next in the scene. The correct answer is the
(real) video caption for the next event in the video; the three incorrect
answers are adversarially generated and human verified, so as to fool machines
but not humans.
Homepage: https://rowanzellers.com/swag/
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@inproceedings{zellers2018swagaf,
title={SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference},
author={Zellers, Rowan and Bisk, Yonatan and Schwartz, Roy and Choi, Yejin},
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
year={2018}
}
"""
class SWAG(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "swag"
DATASET_NAME = "regular"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
out_doc = {
"query": doc["startphrase"],
"choices": [doc["ending0"], doc["ending1"], doc["ending2"], doc["ending3"]],
"gold": int(doc["label"]),
}
return out_doc
def doc_to_text(self, doc):
return doc["query"]
"""
ToxiGen: A Large-Scale Machine-Generated Dataset for Adversarial and Implicit Hate Speech Detection
https://arxiv.org/abs/2203.09509
Classify input text as either hateful or not hateful.
Homepage: https://github.com/microsoft/TOXIGEN
"""
from lm_eval.base import MultipleChoiceTask
import numpy as np
import pandas as pd
_CITATION = """
@inproceedings{hartvigsen2022toxigen,
title={ToxiGen: A Large-Scale Machine-Generated Dataset for Implicit and Adversarial Hate Speech Detection},
author={Hartvigsen, Thomas and Gabriel, Saadia and Palangi, Hamid and Sap, Maarten and Ray, Dipankar and Kamar, Ece},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics},
year={2022}
}
"""
class ToxiGen(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "skg/toxigen-data"
DATASET_NAME = "annotated"
def has_training_docs(self):
return True
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
dataset = self._preprocess_dataset("train")
self._training_docs = list(map(self._process_doc, dataset))
return self._training_docs
def test_docs(self):
if self.has_test_docs():
dataset = self._preprocess_dataset("test")
return map(self._process_doc, dataset)
def _preprocess_dataset(self, split: str):
"""Preprocess the dataset into a list of (text, label) tuples."""
d = pd.DataFrame(self.dataset[split])
text = d["text"]
labels = np.round(((d["toxicity_ai"] + d["toxicity_human"]) > 5.5), 0).astype(
np.int32
)
return [[x, y] for x, y in zip(text, labels)]
def _process_doc(self, doc):
return {
"query": doc[0],
"choices": ["No", "Yes"],
"gold": doc[1],
}
def doc_to_text(self, doc):
return f"Is the following statement hateful? Respond with either Yes or No. Statement: '{doc['query']}'"
......@@ -16,6 +16,20 @@ from lm_eval import metrics
from lm_eval.base import Task, rf
from typing import List
try:
import nagisa
HAS_NAGISA = True
except ImportError:
HAS_NAGISA = False
try:
import jieba
HAS_JIEBA = True
except ImportError:
HAS_JIEBA = False
_CITATION = """
@inproceedings{post-2018-call,
......@@ -41,44 +55,65 @@ def create_tasks_from_benchmarks(benchmark_dict):
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
def version_of(dataset, language_pair):
if language_pair[-2:] in ["zh", "ja"]:
return 1 # changed to use jieba/nagisa
return 0
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
f"{dataset}-{language_pair}": create_translation_task(
dataset, language_pair, version_of(dataset, language_pair)
)
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Language Specifics
########################################
def zh_split(zh_text: List[str]) -> List[str]:
"""Chinese splitting"""
import jieba
if not HAS_JIEBA:
raise ImportError(
"Chinese text splitting requires the `jieba` package. "
"Please install it with:\npip install jieba"
)
return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]
def ja_split(ja_text: List[str]) -> List[str]:
"""Japanese splitting"""
import nagisa
if not HAS_NAGISA:
raise ImportError(
"Japanese text splitting requires the `nagisa` package. "
"Please install it with:\npip install nagisa"
)
return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair, version=0):
class TranslationTask(GeneralTranslationTask):
VERSION = version
def __init__(self):
super().__init__(dataset, language_pair)
return TranslationTask
class GeneralTranslationTask(Task):
VERSION = 0
......@@ -92,8 +127,9 @@ class GeneralTranslationTask(Task):
def download(self, data_dir=None, cache_dir=None, download_mode=None):
# This caches in the users home dir automatically
self.src_file, self.ref_file = \
sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
self.src_file, self.ref_file = sacrebleu.download_test_set(
self.sacrebleu_dataset, self.sacrebleu_language_pair
)
self.src_data, self.ref_data = [
[line.rstrip() for line in sacrebleu.smart_open(file)]
for file in (self.src_file, self.ref_file)
......@@ -117,10 +153,9 @@ class GeneralTranslationTask(Task):
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return [{
"src": src,
"ref": ref
} for src, ref in zip(self.src_data, self.ref_data)]
return [
{"src": src, "ref": ref} for src, ref in zip(self.src_data, self.ref_data)
]
def doc_to_text(self, doc):
language_codes = self.sacrebleu_language_pair.split("-")
......@@ -128,12 +163,18 @@ class GeneralTranslationTask(Task):
tar_lang = code_to_language(language_codes[1])
return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# This shows a single target, though there may be multiple targets in a lang test
return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -143,7 +184,7 @@ class GeneralTranslationTask(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
return rf.greedy_until(ctx, {"until": ["\n"]})
def process_results(self, doc, results):
# Add spaces between words for BLEU score calculation of target languages like Chinese
......
......@@ -29,7 +29,7 @@ _CITATION = """
class TriviaQA(Task):
VERSION = 0
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = None
......@@ -43,10 +43,10 @@ class TriviaQA(Task):
return False
def training_docs(self):
return self.dataset['train']
return self.dataset["train"]
def validation_docs(self):
return self.dataset['validation']
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
......@@ -54,8 +54,14 @@ class TriviaQA(Task):
def doc_to_text(self, doc):
return f"Question: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc['answer']['value']
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
......@@ -69,15 +75,13 @@ class TriviaQA(Task):
def construct_requests(self, doc, ctx):
ret = []
for alias in self._remove_prefixes(doc['answer']['aliases']):
for alias in self._remove_prefixes(doc["answer"]["aliases"]):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
def process_results(self, doc, results):
return {
"acc": float(any(results))
}
return {"acc": float(any(results))}
def aggregation(self):
return {
......@@ -85,6 +89,4 @@ class TriviaQA(Task):
}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment