Commit 2bfa4518 authored by jon-tow

Fix prompt source rank choice accuracy

parent 9f388461
@@ -644,7 +644,6 @@ class PromptSourceTask(Task):
         return f" {target}"

     def doc_to_text(self, doc):
-        print(doc)
         text, _ = self.prompt.apply(doc)
         return text
@@ -661,13 +660,14 @@ class PromptSourceTask(Task):
         """
         _requests = []
-        if self.prompt.metadata.choices_in_prompt:
-            for answer_choice in self.prompt.get_fixed_answer_choices_list():
+        answer_choices_list = self.prompt.get_answer_choices_list(doc)
+        if answer_choices_list:
+            for answer_choice in answer_choices_list:
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"])
+            ll_greedy = rf.greedy_until(ctx, ["\nQ:"])
             _requests.append(ll_greedy)
         return _requests
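Note on the hunk above: `construct_requests` now asks the template for per-document answer choices via `get_answer_choices_list(doc)` and issues one loglikelihood request per choice, falling back to `greedy_until` only when the template defines no choices. A minimal standalone sketch of the same scoring pattern, with a hypothetical `score_continuation` callable standing in for `rf.loglikelihood`:

# Sketch of the per-choice scoring pattern used above. `score_continuation`
# is a made-up stand-in for rf.loglikelihood: it returns a score for
# appending " <choice>" to the context.
from typing import Callable, List

def build_choice_scores(
    ctx: str,
    answer_choices: List[str],
    score_continuation: Callable[[str, str], float],
) -> List[float]:
    # One score per answer choice, in the order the template lists them,
    # so downstream code can recover the prediction by index (argmax).
    return [score_continuation(ctx, f" {choice}") for choice in answer_choices]

# Example call with a trivial length-based scorer (purely illustrative):
scores = build_choice_scores("Premise ... Answer:", ["True", "False"], lambda c, t: -len(t))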
@@ -682,20 +682,22 @@ class PromptSourceTask(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        raise NotImplementedError(
-            "Implement process results using the `prompt.metadata.metrics`. See below."
-        )
-        if self.prompt.metadata.choices_in_prompt:
-            for result, answer_choice in zip(
-                prompt.get_fixed_answer_choices_list(), results
-            ):
-                pass
+        # raise NotImplementedError(
+        #     "Implement process results using the `prompt.metadata.metrics`. See below."
+        # )
+        target = self.doc_to_target(doc).strip()
+        answer_choices_list = self.prompt.get_answer_choices_list(doc)
+        if answer_choices_list:
+            pred = answer_choices_list[np.argmax(results)]
+            return {
+                "acc": pred == target
+            }
         else:
             continuation = results
             # Map metric name to HF metric.
             # TODO(Albert): What is Other?
-            metric_names = prompt.metadata.metrics
+            #metric_names = prompt.metadata.metrics


 class MultipleChoiceTask(Task):
...
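The core of the "rank choice accuracy" fix is in `process_results` above: the prediction is the answer choice with the highest loglikelihood, compared against the stripped target string. A tiny self-contained illustration with made-up loglikelihood values:

# Rank-choice accuracy as computed in process_results above; the numbers
# and the target are invented for illustration.
import numpy as np

answer_choices_list = ["True", "False"]
results = [-0.7, -2.3]           # hypothetical per-choice loglikelihoods
target = "True"                  # what doc_to_target(doc).strip() might return

pred = answer_choices_list[np.argmax(results)]  # highest loglikelihood wins -> "True"
print({"acc": pred == target})                  # {'acc': True}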
@@ -241,15 +241,12 @@ def evaluate(
            for metric, value in metrics.items():
                vals[(task_prompt_name, metric)].append(value)

    # aggregate results
    for (task_prompt_name, metric), items in vals.items():
        task_name, prompt_name = task_prompt_name.split("+")
        results[task_prompt_name]["task_name"] = task_name
        results[task_prompt_name]["prompt_name"] = prompt_name
-        task = task_dict[task_name]
+        task = task_dict[task_prompt_name]
        results[task_prompt_name][metric] = task.aggregation()[metric](items)

    # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
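The aggregation fix above looks tasks up by the combined "<task>+<prompt>" key, the same key used to bucket per-document metric values. A rough sketch of that keying; the prompt name and metric values are invented:

# Minimal illustration of the "<task>+<prompt>" keying assumed by the
# aggregation loop above (hypothetical key and values).
import collections

vals = collections.defaultdict(list)
vals[("wnli+confirm", "acc")].extend([1.0, 0.0, 1.0])

for (task_prompt_name, metric), items in vals.items():
    task_name, prompt_name = task_prompt_name.split("+")
    print(task_name, prompt_name, metric, sum(items) / len(items))
# wnli confirm acc 0.666...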
@@ -276,13 +273,13 @@ def make_table(result_dict):
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []
    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue
+            if "_name" in m:
+                continue
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
...
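The new `"_name" in m` guard keeps the string-valued `task_name`/`prompt_name` entries (stored into `results` by the aggregation loop) away from the `"%.4f" % v` formatting. An illustrative filter pass over a hypothetical results dict:

# Which entries survive the filters in make_table above (values invented).
dic = {"acc": 0.5634, "acc_stderr": 0.0421, "task_name": "wnli", "prompt_name": "confirm"}

for m, v in dic.items():
    if m.endswith("_stderr"):
        continue
    if "_name" in m:            # skips the string-valued metadata entries
        continue
    print(m, "%.4f" % v)        # only numeric metrics reach the formatter
# acc 0.5634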
@@ -30,7 +30,7 @@ _CITATION = """
 class CoQA(PromptSourceTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
+    DATASET_PATH = "coqa"
     DATASET_NAME = None

     def has_training_docs(self):
@@ -57,7 +57,6 @@ class CoQA(PromptSourceTask):
         # answers = []
         # answer_forturn = doc["answers"]["input_text"][turn_id - 1]
         # answers.append(answer_forturn)
-        # additional_answers = doc.get("additional_answers")
         # if additional_answers:
         #     for key in additional_answers:
...
@@ -14,7 +14,7 @@ respect to a wide range of linguistic phenomena found in natural language.
 Homepage: https://gluebenchmark.com/
 """
 import numpy as np

-from lm_eval.base import rf, Task
+from lm_eval.base import PromptSourceTask, rf, Task
 from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
 from lm_eval.utils import general_detokenize
@@ -286,7 +286,7 @@ class QNLI(Task):
     }


-class WNLI(Task):
+class WNLI(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "glue"
     DATASET_NAME = "wnli"
@@ -301,37 +301,14 @@ class WNLI(Task):
         return False

     def training_docs(self):
-        if self._training_docs is None:
-            self._training_docs = list(self.dataset["train"])
-        return self._training_docs
+        # if self._training_docs is None:
+        #     self._training_docs = list()
+        # return self._training_docs
+        return self.dataset["train"]

     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {} True or False?\nAnswer:".format(
-            doc["sentence1"],
-            doc["sentence2"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = not_entailment
-        return " {}".format({0: "False", 1: "True"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_false
-
-    def process_results(self, doc, results):
-        ll_true, ll_false = results
-        pred = ll_true > ll_false
-        gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -343,7 +320,7 @@ class WNLI(Task):
         }


-class RTE(Task):
+class RTE(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "rte"
@@ -365,29 +342,13 @@ class RTE(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {} True or False?\nAnswer:".format(
-            doc["sentence1"],
-            doc["sentence2"],
-        )
-
-    def doc_to_target(self, doc):
-        # 0 = entailment
-        # 1 = not_entailment
-        return " {}".format({0: "True", 1: "False"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_false
-
-    def process_results(self, doc, results):
-        ll_true, ll_false = results
-        pred = ll_false > ll_true
-        gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
+    # def process_results(self, doc, results):
+    #     ll_true, ll_false = results
+    #     pred = ll_false > ll_true
+    #     gold = doc["label"]
+    #     return {
+    #         "acc": pred == gold
+    #     }

     def higher_is_better(self):
         return {
...
@@ -51,47 +51,47 @@ class RACE(PromptSourceTask):
     def has_test_docs(self):
         return True

-    def _collate_data(self, set):
-        if set in self.cache:
-            return self.cache[set]
-        # One big issue with HF's implementation of this dataset: it makes a
-        # separate document for each question; meanwhile, in the GPT3 paper it
-        # is shown that one document is made per passage.
-        r = collections.defaultdict(list)
-        for item in datasets.load_dataset(
-            path=self.DATASET_PATH, name=self.DATASET_NAME
-        )[set]:
-            r[item["article"]].append(item)
-        res = list(
-            r.values()
-            >> each(
-                lambda x: {
-                    "article": x[0]["article"],
-                    "problems": x
-                    >> each(
-                        lambda y: {
-                            "question": y["question"],
-                            "answer": y["answer"],
-                            "options": y["options"],
-                        }
-                    ),
-                }
-            )
-        )
-        self.cache[set] = res
-        return res
+    # def _collate_data(self, set):
+    #     if set in self.cache:
+    #         return self.cache[set]
+    #     # One big issue with HF's implementation of this dataset: it makes a
+    #     # separate document for each question; meanwhile, in the GPT3 paper it
+    #     # is shown that one document is made per passage.
+    #     r = collections.defaultdict(list)
+    #     for item in datasets.load_dataset(
+    #         path=self.DATASET_PATH, name=self.DATASET_NAME
+    #     )[set]:
+    #         r[item["article"]].append(item)
+    #     res = list(
+    #         r.values()
+    #         >> each(
+    #             lambda x: {
+    #                 "article": x[0]["article"],
+    #                 "problems": x
+    #                 >> each(
+    #                     lambda y: {
+    #                         "question": y["question"],
+    #                         "answer": y["answer"],
+    #                         "options": y["options"],
+    #                     }
+    #                 ),
+    #             }
+    #         )
+    #     )
+    #     self.cache[set] = res
+    #     return res

     def training_docs(self):
-        return self._collate_data("train")
+        return self.dataset["train"]

     def validation_docs(self):
-        return self._collate_data("validation")
+        return self.dataset["validation"]

     def test_docs(self):
-        return self._collate_data("test")
+        return self.dataset["test"]

     @classmethod
     def get_answer_option(cls, problem):
...
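With PromptSource templates owning the prompt formatting, RACE now returns the raw Hugging Face splits instead of collating questions per article. A rough sketch of consuming a split the way `training_docs` now returns it; the `load_dataset` config is an assumption, and the field names come from the removed `_collate_data` code:

# Illustrative only: dataset path/config here are assumptions, not taken
# from this commit.
import datasets

race = datasets.load_dataset(path="race", name="high", split="train")
for doc in race:
    # Each doc is one question paired with its passage (no per-article collation).
    _ = (doc["article"], doc["question"], doc["options"], doc["answer"])
    break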
@@ -30,7 +30,7 @@ def main():
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
-    task_dict = tasks.get_task_dict(task_names)
+    task_dict = tasks.get_task_dict_promptsource(task_names)

    description_dict = {}
    if args.description_dict_path:
...