Commit 97b88570 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Remove brackets from documents

parent da016758
import re
import numpy as np import numpy as np
from ..base import rf, mean from ..base import rf, mean
from . common import HFTask from . common import HFTask
...@@ -7,6 +8,15 @@ class HellaSwag(HFTask): ...@@ -7,6 +8,15 @@ class HellaSwag(HFTask):
DATASET_PATH = "hellaswag" DATASET_PATH = "hellaswag"
DATASET_NAME = None DATASET_NAME = None
@classmethod
def remove_brackets(cls, text):
""" Removes brackets from HellaSwag documents.
NOTE: The brackets are artifacts of the WikiHow dataset portion underlying
HellaSwag.
"""
text = re.sub('\[.*?\]', '', text)
return text
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -34,7 +44,8 @@ class HellaSwag(HFTask): ...@@ -34,7 +44,8 @@ class HellaSwag(HFTask):
"plausibly completes the situation." "plausibly completes the situation."
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc['activity_label'] + ': ' + doc['ctx'] + '\n' text = doc['activity_label'] + ': ' + doc['ctx'] + '\n'
return self.remove_brackets(text)
def doc_to_target(self, doc): def doc_to_target(self, doc):
letter_answer = doc['label'] letter_answer = doc['label']
...@@ -49,12 +60,12 @@ class HellaSwag(HFTask): ...@@ -49,12 +60,12 @@ class HellaSwag(HFTask):
else: else:
raise ValueError( raise ValueError(
"HellaSwag from HF datasets contained an invalid answer key") "HellaSwag from HF datasets contained an invalid answer key")
return doc['endings'][index] target = doc['endings'][index]
return self.remove_brackets(target)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
...@@ -62,16 +73,16 @@ class HellaSwag(HFTask): ...@@ -62,16 +73,16 @@ class HellaSwag(HFTask):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
ll_answers = [ ll_answers = []
rf.loglikelihood(ctx, doc['endings'][i])[0] for i in range(4) for i in range(4):
] continuation = self.remove_brackets(doc['endings'][i])
ll_answers.append(rf.loglikelihood(ctx, continuation))
return ll_answers return ll_answers
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param results: :param results:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment