Commit 5bc61283 authored by jon-tow's avatar jon-tow
Browse files

Add `truthfulqa_mc` support

parent 960a0e39
...@@ -67,8 +67,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC): ...@@ -67,8 +67,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc_norm": acc_norm, "acc_norm": acc_norm,
# Bundle answers: (model_answer, model_answer_index, is_correct, question_id). # Bundle answers: (model_answer, model_answer_index, is_correct, question_id).
"answer_bundle": (doc.keys[ans], ans, is_correct, doc.id), "answer_bundle": (doc.keys[ans], ans, is_correct, doc.id),
# Bundle questions: (question_id, question, option_0, option_1, option_2, option_3)
#"question_bundle": (doc.id, doc.question, doc.options),
} }
def higher_is_better(self): def higher_is_better(self):
...@@ -76,7 +74,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC): ...@@ -76,7 +74,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc": True, "acc": True,
"acc_norm": True, "acc_norm": True,
"answer_bundle": True, "answer_bundle": True,
#"question_bundle": True,
} }
def aggregation(self): def aggregation(self):
...@@ -84,9 +81,40 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC): ...@@ -84,9 +81,40 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc": mean, "acc": mean,
"acc_norm": mean, "acc_norm": mean,
"answer_bundle": answer_bundle "answer_bundle": answer_bundle
#"question_bundle": question_bundle,
} }
# UNCOMMENT TO WRITE OUT THE QUESTION TABLE
# TODO: Write a function for this.
#
# def process_results(self, doc: MultipleChoiceDoc, results: typing.List):
# gold = doc.gold
# ans = np.argmax(results)
# is_correct = 1. if ans == gold else 0.
# # Normalize by completion length.
# conts = self.loglikelihood_continuation(doc)
# completion_len = np.array([float(len(i)) for i in conts])
# acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
# return {
# "acc": is_correct,
# "acc_norm": acc_norm,
# # Bundle questions: (question_id, question, option_0, option_1, option_2, option_3)
# "question_bundle": (doc.id, doc.question, doc.options),
# }
# def higher_is_better(self):
# return {
# "acc": True,
# "acc_norm": True,
# "question_bundle": True,
# }
# def aggregation(self):
# return {
# "acc": mean,
# "acc_norm": mean,
# "question_bundle": question_bundle,
# }
def answer_bundle(items): def answer_bundle(items):
""" Bundles answers into a csv file. """ """ Bundles answers into a csv file. """
...@@ -222,6 +250,7 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask): ...@@ -222,6 +250,7 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask):
]) ])
prompt += "\nAnswer:" prompt += "\nAnswer:"
return prompt return prompt
def doc_to_target(self, doc: MultipleChoiceDoc) -> str: def doc_to_target(self, doc: MultipleChoiceDoc) -> str:
return " " + doc.keys[doc.gold] return " " + doc.keys[doc.gold]
......
...@@ -25,11 +25,14 @@ import numpy as np ...@@ -25,11 +25,14 @@ import numpy as np
import sacrebleu import sacrebleu
from rouge_score import rouge_scorer, scoring from rouge_score import rouge_scorer, scoring
from lm_eval.base import rf, Task from lm_eval.base import rf, Task
from lm_eval.base import MultipleChoiceTask
from pathlib import Path from pathlib import Path
from best_download import download_file from best_download import download_file
from ..metrics import mean from ..metrics import mean
from datasets import load_metric from datasets import load_metric
from lm_eval.mctask_experimental import MultipleChoiceDoc
# The default QA preset prompt for all models. # The default QA preset prompt for all models.
QA_PROMPT = ( QA_PROMPT = (
...@@ -48,7 +51,7 @@ QA_PROMPT = ( ...@@ -48,7 +51,7 @@ QA_PROMPT = (
) )
class TruthfulQAMultipleChoice(Task): class TruthfulQAMultipleChoice(MultipleChoiceTask):
VERSION = 1 VERSION = 1
DATASET_PATH = Path('data/truthfulqa/mc') DATASET_PATH = Path('data/truthfulqa/mc')
...@@ -69,22 +72,33 @@ class TruthfulQAMultipleChoice(Task): ...@@ -69,22 +72,33 @@ class TruthfulQAMultipleChoice(Task):
def has_test_docs(self): def has_test_docs(self):
return False return False
def _convert_standard(self, doc):
question = doc["question"]
options = list(doc['mc1_targets'].keys())
# There can be >= 4 option keys.
KEY_LIST = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"]
keys = KEY_LIST[:len(options)]
# The gold answers in `mc1_targets` are always first (index = `0`).
gold = 0
return MultipleChoiceDoc(
question=question,
options=options,
gold=gold,
keys=keys,
)
def training_docs(self): def training_docs(self):
raise NotImplementedError() raise NotImplementedError()
def validation_docs(self): def validation_docs(self):
with open(self.DATASET_PATH / "mc_task.json") as f: with open(self.DATASET_PATH / "mc_task.json") as f:
return json.load(f) data = json.load(f)
for doc in data:
yield self._convert_standard(doc)
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
def doc_to_text(self, doc):
return QA_PROMPT + "\n\nQ: " + doc['question'] + "\nA:"
def doc_to_target(self, doc):
return " "
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context( return super().fewshot_context(
...@@ -94,66 +108,6 @@ class TruthfulQAMultipleChoice(Task): ...@@ -94,66 +108,6 @@ class TruthfulQAMultipleChoice(Task):
description=description description=description
) )
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
def get_lls(targets):
return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]
# MC1 and MC2 targets are not always the same set of strings so we collect
# likelihoods separately for simpler processing.
return get_lls(doc['mc1_targets']) + get_lls(doc['mc2_targets'])
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
def mc1(lls):
# The gold answers in `mc1_targets` are always first (index = `0`).
return np.argmax(lls) == 0
def mc2(lls):
# Split on the first `0` as everything before it is true (`1`).
split_idx = list(doc['mc2_targets'].values()).index(0)
# Compute the normalized probability mass for the correct answer.
ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false))
return sum(p_true)
split_idx = len(doc['mc1_targets'])
mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]
return {
"mc1": mc1(mc1_lls),
"mc2": mc2(mc2_lls)
}
def aggregation(self):
return {
"mc1": mean,
"mc2": mean
}
def higher_is_better(self):
return {
"mc1": True,
"mc2": True
}
class TruthfulQAGeneration(Task): class TruthfulQAGeneration(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = Path('data/truthfulqa/generation') DATASET_PATH = Path('data/truthfulqa/generation')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment