Commit 65c46d22 authored by Jon Tow

Add `HellaSwag` evaluation implementation

parent 7031c324
@@ -40,7 +40,7 @@ TASK_REGISTRY = {
     # "arc_easy": arc.ARCEasy, # not implemented yet
     # "arc_challenge": arc.ARCChallenge, # not implemented yet
     # "quac": quac.QuAC, # not implemented yet
-    # "hellaswag": hellaswag.HellaSwag, # not implemented yet
+    "hellaswag": hellaswag.HellaSwag, # not implemented yet
     # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
     # "sat": sat.SATAnalogies, # not implemented yet
     # "squad": squad.SQuAD, # not implemented yet
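This registry entry is what makes the new task reachable by name. As a rough, self-contained sketch of the dispatch pattern such a name-to-class registry supports (the `get_task` helper and the stub class below are illustrative only, not part of this commit):

# Minimal sketch of name-to-class dispatch over a task registry.
# The stub class and `get_task` helper are hypothetical; only the
# registry shape mirrors the TASK_REGISTRY entry in the diff above.
class HellaSwag:
    DATASET_PATH = "hellaswag"

TASK_REGISTRY = {
    "hellaswag": HellaSwag,
}

def get_task(name):
    """Look up a task class by its registry key and instantiate it."""
    return TASK_REGISTRY[name]()

task = get_task("hellaswag")
print(type(task).__name__)  # -> HellaSwag

The task module diff that follows supplies the actual `HellaSwag` implementation the real registry entry points to.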
 import numpy as np
-from scipy.stats import pearsonr, spearmanr
-from sklearn.metrics import f1_score, matthews_corrcoef
-from tqdm import auto as tqdm_lib
-from . common import HFTask, simple_accuracy_metric, yesno
+from ..base import rf, mean
+from . common import HFTask
 
 
 class HellaSwag(HFTask):
     DATASET_PATH = "hellaswag"
@@ -30,7 +29,9 @@ class HellaSwag(HFTask):
         return self.data["test"]
 
     def fewshot_description(self):
-        return "Label for the relevant action: Sentences describing the context, with an incomplete sentence trailing\nanswer that plausibly completes the situation."
+        return "Label for the relevant action: Sentences describing the " \
+               "context, with an incomplete sentence trailing\nanswer that " \
+               "plausibly completes the situation."
 
     def doc_to_text(self, doc):
         return doc['activity_label'] + ': ' + doc['ctx'] + '\n'
@@ -46,7 +47,8 @@ class HellaSwag(HFTask):
         elif letter_answer == '3':
             index = 3
         else:
-            raise ValueError("HellaSwag from HF datasets contained an invalid answer key")
+            raise ValueError(
+                "HellaSwag from HF datasets contained an invalid answer key")
         return doc['endings'][index]
 
     def construct_requests(self, doc, ctx):
@@ -60,8 +62,10 @@ class HellaSwag(HFTask):
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_answers = [
+            rf.loglikelihood(ctx, doc['endings'][i])[0] for i in range(4)
+        ]
+        return ll_answers
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -73,8 +77,12 @@ class HellaSwag(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = int(doc['label'])
+        pred = np.argmax(results)
+        acc = 1. if pred == gold else 0.
+        return {
+            "acc": acc
+        }
 
     def aggregation(self):
         """
@@ -82,8 +90,9 @@ class HellaSwag(HFTask):
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }
 
     def higher_is_better(self):
         """
@@ -91,5 +100,6 @@ class HellaSwag(HFTask):
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            "acc": True
+        }
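Taken together, the new method bodies follow a standard multiple-choice pattern: `construct_requests` requests the log-likelihood of each of the four candidate endings given the context, `process_results` counts the document as correct when the highest-scoring ending matches the gold `label`, and `aggregation`/`higher_is_better` report mean accuracy as a higher-is-better metric. A minimal standalone sketch of that scoring step, using made-up log-likelihood values in place of real model output and no harness dependencies:

import numpy as np

# Hypothetical document following the HF "hellaswag" fields used above.
doc = {
    "ctx": "A man is removing ice from his car.",
    "endings": ["ending 0", "ending 1", "ending 2", "ending 3"],
    "label": "3",  # gold ending index, stored as a string in the dataset
}

# Stand-ins for the four loglikelihood results the requests would yield,
# one per candidate ending (higher = more likely under the model).
results = [-42.1, -37.8, -40.3, -35.2]

gold = int(doc["label"])          # convert the string label to an index
pred = int(np.argmax(results))    # ending the model finds most likely
acc = 1.0 if pred == gold else 0.0

print({"acc": acc})  # -> {"acc": 1.0}; averaged over documents by the mean aggregator

Note that the endings are compared by raw log-likelihood here; this commit applies no length normalization across candidates.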