Unverified commit f63bc658, authored by Leo Gao, committed by GitHub
Browse files

Merge pull request #142 from jon-tow/hellaswag-refactor

Refactor `HellaSwag` as a `MultipleChoiceTask`
parents e8f9dc71 a2b108b9
import re
import numpy as np
from ..base import rf, mean
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
# NOTE(review): two class headers appear back to back. This looks like the
# removed/added pair of a diff; the commit title ("Refactor `HellaSwag` as a
# `MultipleChoiceTask`") suggests the second header is the intended one - confirm.
class HellaSwag(HFTask):
class HellaSwag(HFTask, MultipleChoiceTask):
# Presumably the dataset identifier consumed by HFTask - verify against
# common.HFTask, which is not visible here.
DATASET_PATH = "hellaswag"
# No dataset name/configuration is set; presumably HFTask treats None as
# "use the default" - TODO confirm.
DATASET_NAME = None
@classmethod
def remove_brackets(cls, text):
    """ Removes every square-bracketed span from a HellaSwag string.

    NOTE: The brackets are artifacts of the WikiHow dataset portion underlying
    HellaSwag.

    :param text: str
        The string to clean.
    :returns: str
        `text` with each shortest `[...]` match deleted. Whitespace around a
        deleted span is left as-is, and a `[` without a closing `]` is kept.
    """
    # Raw string: in a normal literal `\[` is an invalid escape sequence, which
    # newer Python versions warn about. The compiled pattern is unchanged.
    return re.sub(r'\[.*?\]', '', text)
def has_training_docs(self):
    """Flag consulted by `training_docs`; this task always answers `True`."""
    return True
......@@ -24,19 +14,37 @@ class HellaSwag(HFTask):
return True
# Flag method consulted by `test_docs`. Read top to bottom, the first `return`
# wins, so this yields True.
# NOTE(review): the second `return False` is unreachable as written. The two
# returns look like the removed/added pair of a diff - confirm which value the
# task is meant to report.
def has_test_docs(self):
return True
return False
@classmethod
def preprocess(cls, text):
    """Clean a raw HellaSwag string for use as a query or a choice.

    Steps, in order: strip leading/trailing whitespace, turn each " [title]"
    marker into a sentence break (". "), delete every remaining `[...]` span,
    then fold double spaces into single ones.

    :param text: str
        The raw string to clean.
    :returns: str
        The cleaned string.
    """
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub(r'\[.*?\]', '', text)
    # Deleting a tag (or expanding " [title]") leaves the spaces on either side
    # adjacent. The previous single-space -> single-space replace was a no-op,
    # so those doubled spaces survived; replace the pair instead.
    text = text.replace("  ", " ")
    return text
def _convert_standard(self, doc):
    """Reshape one raw record into a dict with "query", "choices" and "gold".

    :param doc:
        A raw record; the keys read are 'activity_label', 'ctx_a', 'ctx_b',
        'endings' and 'label'.
    :returns: dict
        "query" is the preprocessed "<activity_label>: <ctx_a> <Ctx_b>" string
        (ctx_b goes through `str.capitalize`), "choices" holds each ending
        after `preprocess`, and "gold" is `label` converted with `int`.
    """
    context = " ".join((doc["ctx_a"], doc["ctx_b"].capitalize()))
    prompt = self.preprocess(doc['activity_label'] + ': ' + context)
    options = list(map(self.preprocess, doc['endings']))
    answer = int(doc['label'])
    return {"query": prompt, "choices": options, "gold": answer}
def _load_docs(self, docs):
    """Lazily yield each record of `docs` after `_convert_standard`.

    :param docs:
        An iterable of raw records.
    :returns:
        A generator producing one converted record per input record, in order.
    """
    yield from map(self._convert_standard, docs)
# Supplies the training documents.
# As written: when has_training_docs() is true, the raw self.data["train"] is
# returned unconverted; otherwise the parent's training_docs() are fetched and
# passed through _load_docs.
# NOTE(review): has_training_docs() returns True in this file, so the last two
# lines never run as written. The guard and the super() path look like the
# removed/added halves of a diff - confirm which body is intended.
def training_docs(self):
if self.has_training_docs():
return self.data["train"]
docs = super().training_docs()
return self._load_docs(docs)
# Supplies the validation documents: returns the raw self.data["validation"]
# when has_validation_docs() is true; there is no explicit return otherwise.
# NOTE(review): no conversion through _load_docs happens here, unlike the
# super() path seen in neighbouring methods - possibly the pre-refactor body of
# a diff; confirm.
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation"]
# Supplies the test documents: returns the raw self.data["test"] when
# has_test_docs() is true.
# NOTE(review): the trailing two lines fetch super().validation_docs() - not
# test docs - and convert them with _load_docs. They look like the post-refactor
# body of validation_docs that a diff view placed here; confirm before relying
# on this method.
def test_docs(self):
if self.has_test_docs():
return self.data["test"]
docs = super().validation_docs()
return self._load_docs(docs)
def fewshot_description(self):
return "Label for the relevant action: Sentences describing the " \
......@@ -44,73 +52,4 @@ class HellaSwag(HFTask):
"plausibly completes the situation."
def doc_to_text(self, doc):
    """Build the prompt text for `doc`.

    :param doc:
        A record providing 'activity_label' and 'ctx'.
    :returns:
        "<activity_label>: <ctx>" followed by a newline, passed through
        `remove_brackets`.
    """
    label, context = doc['activity_label'], doc['ctx']
    return self.remove_brackets(label + ': ' + context + '\n')
def doc_to_target(self, doc):
    """Return the gold ending for `doc` as a continuation string.

    :param doc:
        A record providing 'label' (one of the strings '0'-'3') and 'endings'.
    :returns:
        A single space followed by the labelled ending after `remove_brackets`.
    :raises ValueError:
        If 'label' is not one of '0', '1', '2', '3'.
    """
    answer_keys = ('0', '1', '2', '3')
    key = doc['label']
    if key not in answer_keys:
        raise ValueError(
            "HellaSwag from HF datasets contained an invalid answer key")
    # The position inside answer_keys is the integer the key spells out.
    ending = doc['endings'][answer_keys.index(key)]
    return " " + self.remove_brackets(ending)
def construct_requests(self, doc, ctx):
    """ Builds one loglikelihood Request per candidate ending.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    :returns:
        A list of four `rf.loglikelihood(ctx, continuation)` requests, one for
        each of the first four entries of doc['endings']; every continuation
        is a space followed by the ending after `remove_brackets`.
    """
    return [
        rf.loglikelihood(ctx, " " + self.remove_brackets(doc['endings'][i]))
        for i in range(4)
    ]
def process_results(self, doc, results):
    """Score a single document against the LM results.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param results:
        The results of the requests created in construct_requests.
    :returns: dict
        {"acc": 1.0} when the argmax position of `results` equals
        int(doc['label']), otherwise {"acc": 0.0}.
    """
    expected = int(doc['label'])
    chosen = np.argmax(results)
    return {"acc": float(chosen == expected)}
def aggregation(self):
    """Name the reducer used for each submetric.

    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics; here "acc" maps to `mean`.
    """
    reducers = {"acc": mean}
    return reducers
def higher_is_better(self):
    """State, per submetric, whether a larger value is the better one.

    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better; here "acc" is True.
    """
    direction = {"acc": True}
    return direction
return doc["query"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment