Unverified Commit 48358352 authored by Jonathan Tow, committed by GitHub

Merge branch 'master' into race-evaluation

parents 6e1a76f0 0f536808
......@@ -59,7 +59,6 @@ class LM(abc.ABC):
class Dataset(abc.ABC):
@abc.abstractmethod
def __init__(self):
self.download()
self._traindocs = None
......
......@@ -15,6 +15,8 @@ from . import sat
from . import arithmetic
from . import lambada
from . import race
from . import piqa
TASK_REGISTRY = {
# GLUE
......@@ -40,11 +42,12 @@ TASK_REGISTRY = {
# Order by benchmark/genre?
"lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet
# "hellaswag": hellaswag.HellaSwag, # not implemented yet
"hellaswag": hellaswag.HellaSwag, # not implemented yet
# "openbookqa": openbookqa.OpenBookQA, # not implemented yet
# "sat": sat.SATAnalogies, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
......
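For orientation: TASK_REGISTRY maps task names to task classes, so the rest of the harness can turn a string such as "hellaswag" or "piqa" into a task instance. A minimal sketch of that lookup, assuming the registry lives in lm_eval.tasks; the get_task helper below is illustrative, not necessarily the harness's own API:

import lm_eval.tasks as tasks  # assumed location of TASK_REGISTRY

def get_task(task_name):
    # Hypothetical helper: resolve a task name to an instantiated task object.
    try:
        return tasks.TASK_REGISTRY[task_name]()
    except KeyError:
        raise KeyError(f"Unknown task: {task_name}")

hellaswag_task = get_task("hellaswag")  # an instance of hellaswag.HellaSwag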
import re
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from ..base import rf, mean
from . common import HFTask
class HellaSwag(HFTask):
DATASET_PATH = "hellaswag"
DATASET_NAME = None
@classmethod
def remove_brackets(cls, text):
""" Removes brackets from HellaSwag documents.
NOTE: The brackets are artifacts of the WikiHow dataset portion underlying
HellaSwag.
"""
text = re.sub(r'\[.*?\]', '', text)
return text
def has_training_docs(self):
return True
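To illustrate what remove_brackets strips: the WikiHow-derived contexts in HellaSwag carry bracketed markup such as "[header]" or "[step]", and the non-greedy regex above deletes every bracketed span. A small example with a made-up context string:

import re

text = "[header] How to make tea. [step] Boil the water, then"
cleaned = re.sub(r'\[.*?\]', '', text)
# cleaned == " How to make tea.  Boil the water, then"  (the bracketed tags are gone)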
......@@ -30,10 +39,13 @@ class HellaSwag(HFTask):
return self.data["test"]
def fewshot_description(self):
return "Label for the relevant action: Sentences describing the context, with an incomplete sentence trailing\nanswer that plausibly completes the situation."
return "Label for the relevant action: Sentences describing the " \
"context, with an incomplete sentence trailing\nanswer that " \
"plausibly completes the situation."
def doc_to_text(self, doc):
return doc['activity_label'] + ': ' + doc['ctx'] + '\n'
text = doc['activity_label'] + ': ' + doc['ctx'] + '\n'
return self.remove_brackets(text)
def doc_to_target(self, doc):
letter_answer = doc['label']
......@@ -46,50 +58,59 @@ class HellaSwag(HFTask):
elif letter_answer == '3':
index = 3
else:
raise ValueError("HellaSwag from HF datasets contained an invalid answer key")
return doc['endings'][index]
raise ValueError(
"HellaSwag from HF datasets contained an invalid answer key")
target = doc['endings'][index]
return self.remove_brackets(target)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
ll_answers = []
for i in range(4):
continuation = self.remove_brackets(doc['endings'][i])
ll_answers.append(rf.loglikelihood(ctx, continuation))
return ll_answers
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
gold = int(doc['label'])
pred = np.argmax(results)
acc = 1. if pred == gold else 0.
return {
"acc": acc
}
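A worked example of the scoring above, under the simplifying assumption that each entry of results is a single log-likelihood float for one ending (the numbers are made up):

import numpy as np

results = [-31.2, -27.8, -35.4, -30.1]   # hypothetical log-likelihoods of the 4 endings
gold = 1                                  # would come from int(doc['label'])
pred = np.argmax(results)                 # index 1: the least negative log-likelihood
acc = 1. if pred == gold else 0.          # 1.0 in this example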
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean
}
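The evaluator applies each aggregation function to the list of per-document values of that submetric, so mean turns the 0/1 accuracies above into an overall accuracy. A trivial illustration with made-up per-document values:

per_doc_acc = [1.0, 0.0, 1.0, 1.0]                 # hypothetical "acc" values, one per document
overall_acc = sum(per_doc_acc) / len(per_doc_acc)  # 0.75, i.e. what mean computes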
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
return {
"acc": True
}
import json
import random
from lm_eval.base import Dataset
from lm_eval.base import Dataset, rf, mean
from ..utils import sh
import os
class PiQA(Dataset):
def __init__(self):
self.download()
def download(self):
#pass
#TODO: don't download if files already there
sh("""
mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
""")
if not os.path.exists('data/piqa'):
#TODO: use best_download
sh("""
mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
""")
def has_training_docs(self):
return True
......@@ -25,11 +24,11 @@ class PiQA(Dataset):
return True
def has_test_docs(self):
return True
return False
def load_docs(self, textfilename, labelfilename):
if labelfilename is not None:
return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(open(labelfilename, 'r')))
return zip([json.loads(entry) for entry in list(open(textfilename,'r'))],list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
else:
return [json.loads(entry) for entry in list(open(textfilename,'r'))]
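For reference, load_docs pairs each JSON document with its stripped label line when a label file is given, and yields bare documents otherwise. The shape of one labelled item, with made-up content but the real PIQA field names:

doc = (
    {"goal": "How do you open a stuck jar?",                      # the prompt / goal
     "sol1": "Tap the lid's edge, then twist it off.",
     "sol2": "Tap the jar's base, then twist the lid on tighter."},
    "0",                                                           # label line: "0" means sol1 is correct
)
goal, label = doc[0]["goal"], int(doc[1])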
......@@ -39,62 +38,40 @@ class PiQA(Dataset):
def validation_docs(self):
return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')
def test_docs(self):
return self.load_docs('data/piqa/piqa-test.jsonl', None)
#def test_docs(self):
# return self.load_docs('data/piqa/piqa-test.jsonl', None)
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
#TODO: check if oa uses newline
return doc['goal'] + ' '
return doc[0]['goal']
def doc_to_target(self, doc):
rightanswer = int(doc[1][0]) + 1
return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
#TODO: check if oa uses newline
rightanswer = int(doc[1]) + 1
return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return ll_1, ll_2
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
ll_1, ll_2 = results
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
'acc': (ll_1 > ll_2) == (int(doc[1]) == 0)
}
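The accuracy expression above is compact, so a worked example may help: the model "predicts" sol1 whenever ll_1 > ll_2, and the gold answer is sol1 exactly when the label is 0, so the document is scored correct when those two booleans agree (numbers are made up):

ll_1, ll_2 = -12.3, -15.9                   # hypothetical log-likelihoods; model prefers sol1
label = "0"                                 # gold label: "0" means sol1 is correct
acc = (ll_1 > ll_2) == (int(label) == 0)    # True == True -> correct (counts as 1 under mean)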
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
'acc': mean
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
return {
'acc': True
}
\ No newline at end of file