Unverified Commit 94d782a0 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #97 from EleutherAI/unit-testing

Implement unit testing and fix lots of problems with tasks
parents 693c19e2 60a6fd8c
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import random
from lm_eval.base import LM
from . import MODEL_REGISTRY
@MODEL_REGISTRY.register("dummy")
class DummyLM(LM):
    """Stub model that returns random log-likelihoods.

    Useful for exercising the evaluation pipeline without a real model.
    """

    def __init__(self):
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        # The dummy model takes no construction arguments.
        return cls()

    def loglikelihood(self, requests):
        # One (logprob, is_greedy) pair per request: a random negative float
        # for the log-likelihood, and is_greedy always False.
        # (A garbled merge had nested the obsolete
        # `loglikelihood(self, context, continuation)` signature here, which
        # orphaned the `return res` below; the obsolete version is removed.)
        res = []
        for _ in requests:
            res.append((-random.random(), False))
        return res

    def greedy_until(self, requests):
        # TODO: implement
        pass
\ No newline at end of file
......@@ -19,7 +19,6 @@ class GPT2LM(LM):
return cls(device=args.get("device", "cpu"))
def loglikelihood(self, requests):
print(requests)
res = []
# TODO: vectorize properly
for context, continuation in tqdm(requests):
......
......@@ -32,6 +32,8 @@ class GPT3LM(LM):
return cls(engine=args.get("engine", "davinci"))
def loglikelihood(self, context, continuation):
# TODO: implement new framework
import openai
context_enc = self.tokenizer.encode(context)
......
......@@ -23,7 +23,7 @@ TASK_REGISTRY = {
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
"stsb": glue.STSB,
#"stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
......@@ -33,23 +33,25 @@ TASK_REGISTRY = {
"multirc": superglue.MultiRC,
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
#"wsc": superglue.SGWinogradSchemaChallenge, # not implemented yet
# Order by benchmark/genre?
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
"quac": quac.QuAC,
"hellaswag": hellaswag.HellaSwag,
"openbookqa": openbookqa.OpenBookQA,
"sat": sat.SATAnalogies,
"squad": squad.SQuAD,
"race": race.RACE,
"naturalqs": naturalqs.NaturalQs,
"webqs": webqs.WebQs,
"wsc273": wsc273.WinogradSchemaChallenge273,
"winogrande": winogrande.Winogrande,
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet
# "hellaswag": hellaswag.HellaSwag, # not implemented yet
# "openbookqa": openbookqa.OpenBookQA, # not implemented yet
# "sat": sat.SATAnalogies, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
# "race": race.RACE, # not implemented yet
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "webqs": webqs.WebQs, # not implemented yet
# "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
# "anli_r1": anli.ANLIRound1, # not implemented yet
# "anli_r2": anli.ANLIRound2, # not implemented yet
# "anli_r3": anli.ANLIRound3, # not implemented yet
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......
......@@ -12,7 +12,6 @@ class Arithmetic(Dataset):
def __init__(self):
super().__init__()
self.set_docs()
def download(self):
file_name, checksum = self.get_file_download_info()
......@@ -20,6 +19,7 @@ class Arithmetic(Dataset):
if not os.path.exists(self.directory):
os.makedirs(self.directory)
download_file(url, self.directory+file_name, checksum)
self.set_docs()
@abc.abstractmethod
def get_file_download_info(self):
......
......@@ -11,6 +11,8 @@ class HFTask(Dataset):
def __init__(self):
    # _training_docs caches the expanded training set; populated lazily by
    # subclasses' training_docs() on first access.
    super().__init__()
    self._training_docs = None
def download(self):
    # Fetch the dataset from the HuggingFace hub (the `datasets` library
    # caches it locally), keyed by the subclass's DATASET_PATH/DATASET_NAME.
    self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self):
......
......@@ -11,7 +11,7 @@ class DROP(Dataset):
DATAFOLDER = Path(__file__).parent / "../../data/drop"
def __init__(self):
    # Download before calling Dataset.__init__ — presumably the base init
    # reads the downloaded files. TODO(review): confirm this ordering
    # requirement; sibling tasks call super().__init__() first.
    self.download()
    super().__init__()
def has_training_docs(self):
"""Whether the task has a training set"""
......
......@@ -54,16 +54,13 @@ class RACE(HFTask):
# TODO: figure out description
return ""
def doc_to_text(self, doc):
    # TODO: implement for the new framework.
    # The old pipe-operator rendering (article text plus enumerated Q/A
    # pairs) was left orphaned after this def by a garbled merge — it
    # referenced `r` out of any scope — and has been removed pending
    # conversion. Callers currently get None.
    pass
def doc_to_target(self, doc):
    # TODO: implement
    # Stub pending conversion to the new framework; returns None for now.
    pass
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......
......@@ -9,6 +9,8 @@ from ..utils import sh
class SATAnalogies(Dataset):
NEEDS_MANUAL_DL = True
def __init__(self):
super().__init__()
......
......@@ -5,8 +5,8 @@ from ..utils import sh
import csv
class StoryCloze(Dataset):
def __init__(self):
self.download()
NEEDS_MANUAL_DL = True
def download(self):
    # Manual-download task (see NEEDS_MANUAL_DL): nothing is fetched
    # automatically.
    #TODO: replace with Eye link
    pass
......@@ -30,7 +30,7 @@ class StoryCloze(Dataset):
def validation_docs(self):
    # Winter-2018 validation split of the Story Cloze test.
    # (A garbled merge had duplicated this return statement; the dead
    # second copy is removed.)
    return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
def test_docs(self):
return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
......
......@@ -75,6 +75,7 @@ class CommitmentBank(HFTask):
return True
def fewshot_description(self):
    # TODO: figure out actual description
    # NOTE: the two literals are joined by implicit concatenation; a
    # separator space is required at the join point, otherwise the prompt
    # reads "...committedto the truth...".
    return "Given a premise and a hypothesis, classify whether the author of the premise is committed " \
           "to the truth of the hypothesis. The three possible labels are true, false or neither."
......@@ -145,6 +146,7 @@ class Copa(HFTask):
return True
def fewshot_description(self):
    # TODO: figure out actual description
    # NOTE: the two literals are joined by implicit concatenation; a
    # separator space is required at the join point, otherwise the prompt
    # reads "...without,choose...".
    return "Given a premise and one alternative with a causal relation to the premise and another without, " \
           "choose the more plausible alternative"
......@@ -208,6 +210,7 @@ class MultiRC(HFTask):
return True
def fewshot_description(self):
    # TODO: figure out actual description
    # Header line prepended to few-shot prompts for MultiRC.
    return "READING COMPREHENSION ANSWER KEY"
def doc_to_text(self, doc):
......@@ -260,24 +263,37 @@ class ReCoRD(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
    # TODO: figure out actual description
    # Intentionally empty for now: no task description is prepended.
    return ""
def training_docs(self):
    # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
    # Each doc consists of multiple answer candidates, each of which is scored yes/no.
    # Hence, we use one "doc" for each (context + passage, answer) pair.
    # Moreover, we only use the correct answers for context packing
    # (This is not an issue for evaluation, where we can directly score multiple candidates at once).
    #
    # (A garbled merge left two complete versions of this method back-to-back;
    # only the current, lazily-cached version is kept.)
    if self._training_docs is None:
        self._training_docs = []
        for doc in self.data["train"]:
            for entity in list(set(doc["entities"])):
                self._training_docs.append({
                    "passage": doc["passage"],
                    "query": doc["query"],
                    "entity": entity,
                    "label": entity in doc["answers"],
                })
    return self._training_docs
def validation_docs(self):
    # Yield one doc per (passage/query, candidate entity) pair, labeled by
    # whether the candidate appears among the gold answers.
    for raw in self.data["validation"]:
        answers = raw["answers"]
        for candidate in set(raw["entities"]):
            yield {
                "passage": raw["passage"],
                "query": raw["query"],
                "entity": candidate,
                "label": candidate in answers,
            }
def doc_to_text(self, doc):
initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
......@@ -296,7 +312,7 @@ class ReCoRD(HFTask):
def construct_requests(self, doc, ctx):
    # Each doc carries exactly one candidate entity (see training_docs /
    # validation_docs, which emit a singular "entity" string), so exactly
    # one loglikelihood request is built.
    # NOTE(review): iterating doc["entity"] — a string — would issue one
    # request per *character*; the whole entity must be used instead.
    return [
        rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=doc["entity"])),
    ]
......@@ -342,6 +358,10 @@ class WordsInContext(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nanswer:".format(
......@@ -405,6 +425,7 @@ class SGWinogradSchemaChallenge(HFTask):
return self._training_docs
def fewshot_description(self):
# TODO: figure out actual description
return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \
......
import os
import json
import random
from lm_eval.base import Dataset
from ..utils import sh
class TriviaQA(Dataset):
def __init__(self):
    # NOTE(review): unlike sibling tasks, there is no super().__init__()
    # call here — confirm Dataset.__init__ is safe to skip.
    self.download()
def download(self):
    # Fetch and unpack the unfiltered TriviaQA archive into data/triviaqa,
    # skipping the download when the directory already exists.
    # (A garbled merge left both the old unconditional download and this
    # guarded one, fetching the archive twice; only the guarded version is
    # kept.)
    # TODO: the guard only checks directory existence — a partial earlier
    # download would also be skipped; consider checksumming.
    if not os.path.exists('data/triviaqa'):
        sh("""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
""")
def has_training_docs(self):
    """Whether the task has a training set."""
    return True
......
......@@ -71,12 +71,14 @@ class WinogradSchemaChallenge273(Dataset):
docs.append(doc)
return docs
def doc_to_text(self, doc):
    # TODO: implement
    # WSC273 previously only wrote out full examples; partial evaluation for
    # the new framework still needs implementing. (A garbled merge left the
    # old full-example version defined immediately above this stub, where it
    # was shadowed dead code; it has been removed.) Returns None for now.
    pass
def doc_to_target(self, doc):
    # TODO: implement
    # Stub pending conversion to the new framework; returns None for now.
    pass
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment