Unverified commit 48358352, authored by Jonathan Tow, committed by GitHub

Merge branch 'master' into race-evaluation

parents 6e1a76f0 0f536808
@@ -59,7 +59,6 @@ class LM(abc.ABC):
 class Dataset(abc.ABC):
-    @abc.abstractmethod
     def __init__(self):
         self.download()
         self._traindocs = None
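With @abc.abstractmethod removed, Dataset.__init__ is now a concrete method that every task inherits, so subclasses no longer need a boilerplate __init__ whose only job is to call download() (the PiQA change further down removes exactly that). A minimal sketch of the effect, with ToyTask as a hypothetical stand-in for a real task:

    import abc

    class Dataset(abc.ABC):
        def __init__(self):
            # Concrete now: runs for every subclass unless overridden.
            self.download()
            self._traindocs = None

        @abc.abstractmethod
        def download(self):
            pass

    class ToyTask(Dataset):
        # No __init__ needed; the inherited one calls download().
        def download(self):
            print("downloading...")

    ToyTask()  # prints "downloading..."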
@@ -15,6 +15,8 @@ from . import sat
 from . import arithmetic
 from . import lambada
 from . import race
+from . import piqa
+
 
 TASK_REGISTRY = {
     # GLUE
@@ -40,11 +42,12 @@ TASK_REGISTRY = {
     # Order by benchmark/genre?
     "lambada": lambada.LAMBADA,
+    "piqa": piqa.PiQA,
     # "arc_easy": arc.ARCEasy, # not implemented yet
     # "arc_challenge": arc.ARCChallenge, # not implemented yet
     # "quac": quac.QuAC, # not implemented yet
-    # "hellaswag": hellaswag.HellaSwag, # not implemented yet
+    "hellaswag": hellaswag.HellaSwag, # not implemented yet
     # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
     # "sat": sat.SATAnalogies, # not implemented yet
     # "squad": squad.SQuAD, # not implemented yet
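TASK_REGISTRY maps a task name to its class, so enabling a task is a one-line registry entry; this diff registers the new piqa module and uncomments hellaswag. A rough sketch of how such a registry is typically consumed (the get_task helper and the stand-in classes are assumptions for illustration, not part of this diff):

    # Stand-ins for the real task classes registered above.
    class PiQA: ...
    class HellaSwag: ...

    TASK_REGISTRY = {
        "piqa": PiQA,
        "hellaswag": HellaSwag,
    }

    def get_task(task_name):
        try:
            return TASK_REGISTRY[task_name]
        except KeyError:
            raise KeyError(f"Unknown task: {task_name}")

    task = get_task("piqa")()  # the real PiQA would download its data here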
+import re
 import numpy as np
-from scipy.stats import pearsonr, spearmanr
-from sklearn.metrics import f1_score, matthews_corrcoef
-from tqdm import auto as tqdm_lib
-from . common import HFTask, simple_accuracy_metric, yesno
+from ..base import rf, mean
+from . common import HFTask
 
 
 class HellaSwag(HFTask):
     DATASET_PATH = "hellaswag"
     DATASET_NAME = None
 
+    @classmethod
+    def remove_brackets(cls, text):
+        """ Removes brackets from HellaSwag documents.
+        NOTE: The brackets are artifacts of the WikiHow dataset portion underlying
+        HellaSwag.
+        """
+        text = re.sub('\[.*?\]', '', text)
+        return text
+
     def has_training_docs(self):
         return True
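The new remove_brackets helper exists because the WikiHow-sourced portion of HellaSwag carries editorial bracket spans, and the non-greedy pattern deletes each bracketed run without spanning across neighbors. A quick illustration on a made-up string (raw-string form of the same pattern):

    import re

    text = "Then he ties the knot. [The man looks at the camera.] He smiles."
    print(re.sub(r'\[.*?\]', '', text))
    # -> "Then he ties the knot.  He smiles."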
@@ -30,10 +39,13 @@ class HellaSwag(HFTask):
         return self.data["test"]
 
     def fewshot_description(self):
-        return "Label for the relevant action: Sentences describing the context, with an incomplete sentence trailing\nanswer that plausibly completes the situation."
+        return "Label for the relevant action: Sentences describing the " \
+               "context, with an incomplete sentence trailing\nanswer that " \
+               "plausibly completes the situation."
 
     def doc_to_text(self, doc):
-        return doc['activity_label'] + ': ' + doc['ctx'] + '\n'
+        text = doc['activity_label'] + ': ' + doc['ctx'] + '\n'
+        return self.remove_brackets(text)
 
     def doc_to_target(self, doc):
         letter_answer = doc['label']
@@ -46,50 +58,59 @@ class HellaSwag(HFTask):
         elif letter_answer == '3':
             index = 3
         else:
-            raise ValueError("HellaSwag from HF datasets contained an invalid answer key")
-        return doc['endings'][index]
+            raise ValueError(
+                "HellaSwag from HF datasets contained an invalid answer key")
+        target = doc['endings'][index]
+        return self.remove_brackets(target)
 
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
 
         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_answers = []
+        for i in range(4):
+            continuation = self.remove_brackets(doc['endings'][i])
+            ll_answers.append(rf.loglikelihood(ctx, continuation))
+        return ll_answers
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
 
         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = int(doc['label'])
+        pred = np.argmax(results)
+        acc = 1. if pred == gold else 0.
+        return {
+            "acc": acc
+        }
 
     def aggregation(self):
         """
         :returns: {str: [float] -> float}
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }
 
     def higher_is_better(self):
         """
         :returns: {str: bool}
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            "acc": True
+        }
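Taken together: construct_requests asks for one log-likelihood per candidate ending, and process_results counts the document as correct when the highest-scoring ending is the gold one; aggregation then averages the per-document acc values. A walk-through with invented numbers (real results come back from the rf.loglikelihood request pipeline):

    import numpy as np

    doc = {"label": "2", "endings": ["opens the door", "sits down",
                                     "pours the milk", "walks away"]}
    results = [-12.3, -10.8, -9.1, -11.5]  # hypothetical log-likelihoods

    gold = int(doc["label"])
    pred = np.argmax(results)  # index 2, the least-negative score
    print({"acc": 1. if pred == gold else 0.})  # {'acc': 1.0}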
 import json
 import random
-from lm_eval.base import Dataset
+from lm_eval.base import Dataset, rf, mean
 from ..utils import sh
+import os
 
 
 class PiQA(Dataset):
-    def __init__(self):
-        self.download()
-
     def download(self):
-        #pass
-        #TODO: don't download if files already there
-        sh("""
-        mkdir -p data/piqa
-        wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
-        wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
-        wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
-        wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
-        wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
-        """)
+        if not os.path.exists('data/piqa'):
+            #TODO: use best_download
+            sh("""
+            mkdir -p data/piqa
+            wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
+            wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
+            wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
+            wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
+            wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
+            """)
 
     def has_training_docs(self):
         return True
@@ -25,11 +24,11 @@ class PiQA(Dataset):
         return True
 
     def has_test_docs(self):
-        return True
+        return False
 
     def load_docs(self, textfilename, labelfilename):
         if labelfilename != None:
-            return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(open(labelfilename, 'r')))
+            return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
         else:
             return [json.loads(entry) for entry in list(open(textfilename,'r'))]
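The added map(lambda x: x.strip(), ...) matters because each line read from the labels file ends in a newline; the old code compensated by indexing doc[1][0], whereas after stripping, the rest of the class can call int(doc[1]) directly (as doc_to_target and process_results below now do). Illustration:

    raw_line = "1\n"            # as read from piqa-train-labels.lst
    label = raw_line.strip()    # "1"
    assert int(label) == 1
    # The pre-strip code had to grab the first character instead:
    assert int(raw_line[0]) == 1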
@@ -39,62 +38,40 @@ class PiQA(Dataset):
 
     def validation_docs(self):
         return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')
 
-    def test_docs(self):
-        return self.load_docs('data/piqa/piqa-test.jsonl', None)
+    #def test_docs(self):
+    #    return self.load_docs('data/piqa/piqa-test.jsonl', None)
 
     def fewshot_description(self):
         # TODO: figure out fewshot description
         return ""
 
     def doc_to_text(self, doc):
-        #TODO: check if oa uses newline
-        return doc['goal'] + ' '
+        return doc[0]['goal']
 
     def doc_to_target(self, doc):
-        rightanswer = int(doc[1][0]) + 1
-        return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
+        #TODO: check if oa uses newline
+        rightanswer = int(doc[1]) + 1
+        return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
 
     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
+        ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])
+        return ll_1, ll_2
 
     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_1, ll_2 = results
+
+        return {
+            'acc': (ll_1 > ll_2) == (int(doc[1]) == 0)
+        }
 
     def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            'acc': mean
+        }
 
     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            'acc': True
+        }
\ No newline at end of file
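For PIQA the two requests score sol1 and sol2 against the same context, and accuracy records whether the model's preference agrees with the gold label ('0' means sol1 is correct, '1' means sol2). A worked example with invented log-likelihoods, using the (json_doc, label) tuples that load_docs yields:

    doc = ({"goal": "open a jar",
            "sol1": "twist the lid.",
            "sol2": "twist the glass."}, "0")
    ll_1, ll_2 = -4.2, -7.9  # hypothetical scores for sol1 and sol2

    print({"acc": (ll_1 > ll_2) == (int(doc[1]) == 0)})  # {'acc': True}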