Unverified Commit 1050109b authored by Leo Gao, committed by GitHub

Merge pull request #136 from jon-tow/openbookqa-evaluation

Implement `OpenBookQA` evaluation
parents 359114fd aa1d7293
@@ -56,7 +56,7 @@ TASK_REGISTRY = {
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC, # not implemented yet
     "hellaswag": hellaswag.HellaSwag, # not implemented yet
-    # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
+    "openbookqa": openbookqa.OpenBookQA,
     # "sat": sat.SATAnalogies, # not implemented yet
     # "squad": squad.SQuAD, # not implemented yet
     "race": race.RACE,
...
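As an aside, here is a minimal sketch of how the newly registered task could be driven from Python. The `lm_eval.tasks` module path and the no-argument constructor are assumptions, not shown in this diff:

# Hedged sketch: assumes TASK_REGISTRY lives in lm_eval.tasks and that the task
# class can be constructed without arguments (HFTask loads the HF dataset itself).
from lm_eval import tasks

task = tasks.TASK_REGISTRY["openbookqa"]()
doc = next(iter(task.validation_docs()))   # docs arrive already converted by _convert_standard
print(task.doc_to_text(doc))               # prints the question stem (the "query" field)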
-import numpy as np
-from scipy.stats import pearsonr, spearmanr
-from sklearn.metrics import f1_score, matthews_corrcoef
-from tqdm import auto as tqdm_lib
-from . common import HFTask, simple_accuracy_metric, yesno
+from lm_eval.base import MultipleChoiceTask
+from .common import HFTask


-class OpenBookQA(HFTask):
+class OpenBookQA(HFTask, MultipleChoiceTask):
     DATASET_PATH = "openbookqa"
     DATASET_NAME = "main"

@@ -17,82 +15,34 @@ class OpenBookQA(HFTask):
     def has_test_docs(self):
         return True

+    def _convert_standard(self, doc):
+        out_doc = {
+            "id": doc["id"],
+            "query": doc["question_stem"],
+            "choices": doc["choices"]["text"],
+            "gold": ["A", "B", "C", "D"].index(doc["answerKey"].strip()),
+        }
+        return out_doc
+
+    def _load_docs(self, docs):
+        for record in docs:
+            yield self._convert_standard(record)
+
     def training_docs(self):
-        if self.has_training_docs():
-            if self._training_docs is None:
-                self._training_docs = list(self.data["train"])
-            return self._training_docs
+        docs = super().training_docs()
+        return self._load_docs(docs)

     def validation_docs(self):
-        if self.has_validation_docs():
-            return self.data["validation"]
+        docs = super().validation_docs()
+        return self._load_docs(docs)

     def test_docs(self):
-        if self.has_test_docs():
-            return self.data["test"]
+        docs = super().test_docs()
+        return self._load_docs(docs)

     def fewshot_description(self):
         # TODO: figure out fewshot description
         return ""

     def doc_to_text(self, doc):
-        return doc['question_stem'] + '\n'
+        return doc["query"]
-    def doc_to_target(self, doc):
-        letter_answer = doc['answerKey']
-        if letter_answer == 'A':
-            index = 0
-        elif letter_answer == 'B':
-            index = 1
-        elif letter_answer == 'C':
-            index = 2
-        elif letter_answer == 'D':
-            index = 3
-        else:
-            raise ValueError("OpenBookQA from HF datasets contained an invalid answer key")
-        return doc['choices']['text'][index] + '.'
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
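For readers unfamiliar with the MultipleChoiceTask document format, here is an illustrative sketch of what `_convert_standard` produces. The field values below are hypothetical; only the field names come from the diff above:

# Hypothetical raw HF record with the fields that _convert_standard reads.
raw = {
    "id": "7-123",
    "question_stem": "Which material conducts electricity best?",
    "choices": {"text": ["copper wire", "rubber band", "wooden stick", "glass rod"],
                "label": ["A", "B", "C", "D"]},
    "answerKey": "A",
}

# _convert_standard maps it onto the schema MultipleChoiceTask consumes:
# {"id": "7-123",
#  "query": "Which material conducts electricity best?",
#  "choices": ["copper wire", "rubber band", "wooden stick", "glass rod"],
#  "gold": 0}   # index of the correct choice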