"vscode:/vscode.git/clone" did not exist on "889809456156039a715182720ac0e589ae6e12f2"
Unverified Commit 94d782a0 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #97 from EleutherAI/unit-testing

Implement unit testing and fix lots of problems with tasks
parents 693c19e2 60a6fd8c
# REMINDER: this code needs to be rewritten for the new framework.
# Remove this comment when the code is fully converted.
import random

from lm_eval.base import LM
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("dummy")
class DummyLM(LM):
    """A no-op language model returning random scores.

    Useful for smoke-testing the evaluation pipeline without loading a
    real model: it satisfies the ``LM`` interface but performs no actual
    inference.
    """

    def __init__(self):
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Construct the model; ``arg_string`` is accepted but ignored."""
        return cls()

    def loglikelihood(self, requests):
        """Return one ``(logprob, is_greedy)`` pair per request.

        The log-probability is a uniformly random negative float and
        ``is_greedy`` is always False, since no real scoring happens.
        """
        res = []
        for _ in requests:
            res.append((-random.random(), False))
        return res

    def greedy_until(self, requests):
        # TODO: implement
        pass
...@@ -19,7 +19,6 @@ class GPT2LM(LM): ...@@ -19,7 +19,6 @@ class GPT2LM(LM):
return cls(device=args.get("device", "cpu")) return cls(device=args.get("device", "cpu"))
def loglikelihood(self, requests): def loglikelihood(self, requests):
print(requests)
res = [] res = []
# TODO: vectorize properly # TODO: vectorize properly
for context, continuation in tqdm(requests): for context, continuation in tqdm(requests):
......
...@@ -32,6 +32,8 @@ class GPT3LM(LM): ...@@ -32,6 +32,8 @@ class GPT3LM(LM):
return cls(engine=args.get("engine", "davinci")) return cls(engine=args.get("engine", "davinci"))
def loglikelihood(self, context, continuation): def loglikelihood(self, context, continuation):
# TODO: implement new framework
import openai import openai
context_enc = self.tokenizer.encode(context) context_enc = self.tokenizer.encode(context)
......
...@@ -23,7 +23,7 @@ TASK_REGISTRY = { ...@@ -23,7 +23,7 @@ TASK_REGISTRY = {
"rte": glue.RTE, "rte": glue.RTE,
"qnli": glue.QNLI, "qnli": glue.QNLI,
"qqp": glue.QQP, "qqp": glue.QQP,
"stsb": glue.STSB, #"stsb": glue.STSB, # not implemented yet
"sst": glue.SST, "sst": glue.SST,
"wnli": glue.WNLI, "wnli": glue.WNLI,
# SuperGLUE # SuperGLUE
...@@ -33,23 +33,25 @@ TASK_REGISTRY = { ...@@ -33,23 +33,25 @@ TASK_REGISTRY = {
"multirc": superglue.MultiRC, "multirc": superglue.MultiRC,
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, #"wsc": superglue.SGWinogradSchemaChallenge, # not implemented yet
# Order by benchmark/genre? # Order by benchmark/genre?
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, # "arc_easy": arc.ARCEasy, # not implemented yet
"quac": quac.QuAC, # "arc_challenge": arc.ARCChallenge, # not implemented yet
"hellaswag": hellaswag.HellaSwag, # "quac": quac.QuAC, # not implemented yet
"openbookqa": openbookqa.OpenBookQA, # "hellaswag": hellaswag.HellaSwag, # not implemented yet
"sat": sat.SATAnalogies, # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
"squad": squad.SQuAD, # "sat": sat.SATAnalogies, # not implemented yet
"race": race.RACE, # "squad": squad.SQuAD, # not implemented yet
"naturalqs": naturalqs.NaturalQs, # "race": race.RACE, # not implemented yet
"webqs": webqs.WebQs, # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"wsc273": wsc273.WinogradSchemaChallenge273, # "webqs": webqs.WebQs, # not implemented yet
"winogrande": winogrande.Winogrande, # "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
"anli_r1": anli.ANLIRound1, # "winogrande": winogrande.Winogrande, # not implemented yet
"anli_r2": anli.ANLIRound2, # "anli_r1": anli.ANLIRound1, # not implemented yet
"anli_r3": anli.ANLIRound3, # "anli_r2": anli.ANLIRound2, # not implemented yet
# "anli_r3": anli.ANLIRound3, # not implemented yet
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......
...@@ -12,7 +12,6 @@ class Arithmetic(Dataset): ...@@ -12,7 +12,6 @@ class Arithmetic(Dataset):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.set_docs()
def download(self): def download(self):
file_name, checksum = self.get_file_download_info() file_name, checksum = self.get_file_download_info()
...@@ -20,6 +19,7 @@ class Arithmetic(Dataset): ...@@ -20,6 +19,7 @@ class Arithmetic(Dataset):
if not os.path.exists(self.directory): if not os.path.exists(self.directory):
os.makedirs(self.directory) os.makedirs(self.directory)
download_file(url, self.directory+file_name, checksum) download_file(url, self.directory+file_name, checksum)
self.set_docs()
@abc.abstractmethod @abc.abstractmethod
def get_file_download_info(self): def get_file_download_info(self):
......
...@@ -11,6 +11,8 @@ class HFTask(Dataset): ...@@ -11,6 +11,8 @@ class HFTask(Dataset):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._training_docs = None self._training_docs = None
def download(self):
self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME) self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self): def has_training_docs(self):
......
...@@ -11,7 +11,7 @@ class DROP(Dataset): ...@@ -11,7 +11,7 @@ class DROP(Dataset):
DATAFOLDER = Path(__file__).parent / "../../data/drop" DATAFOLDER = Path(__file__).parent / "../../data/drop"
def __init__(self): def __init__(self):
self.download() super().__init__()
def has_training_docs(self): def has_training_docs(self):
"""Whether the task has a training set""" """Whether the task has a training set"""
......
...@@ -54,16 +54,13 @@ class RACE(HFTask): ...@@ -54,16 +54,13 @@ class RACE(HFTask):
# TODO: figure out description # TODO: figure out description
return "" return ""
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc):
r = "Article:\n" + doc['article'] + '\n\n' # TODO: implement
pass
r += doc['problems'] >> apply(enumerate) >> each( def doc_to_target(self, doc):
lambda x: 'Q: ' + x[1]['question'] + '\n\nA:' # TODO: implement
+ ((' ' + x[1]['options'][['A', 'B', 'C', 'D'].index(x[1]['answer'])]) \ pass
if x[0] != len(doc['problems']) - 1 or include_target else '')) \
>> join('\n\n')
return r
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
......
...@@ -9,6 +9,8 @@ from ..utils import sh ...@@ -9,6 +9,8 @@ from ..utils import sh
class SATAnalogies(Dataset): class SATAnalogies(Dataset):
NEEDS_MANUAL_DL = True
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -5,8 +5,8 @@ from ..utils import sh ...@@ -5,8 +5,8 @@ from ..utils import sh
import csv import csv
class StoryCloze(Dataset): class StoryCloze(Dataset):
def __init__(self): NEEDS_MANUAL_DL = True
self.download()
def download(self): def download(self):
#TODO: replace with Eye link #TODO: replace with Eye link
pass pass
...@@ -30,7 +30,7 @@ class StoryCloze(Dataset): ...@@ -30,7 +30,7 @@ class StoryCloze(Dataset):
def validation_docs(self): def validation_docs(self):
return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv") return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")
def test_docs(self): def test_docs(self):
return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv") return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")
......
...@@ -75,6 +75,7 @@ class CommitmentBank(HFTask): ...@@ -75,6 +75,7 @@ class CommitmentBank(HFTask):
return True return True
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and a hypothesis, classify whether the author of the premise is committed" \ return "Given a premise and a hypothesis, classify whether the author of the premise is committed" \
"to the truth of the hypothesis. The three possible labels are true, false or neither." "to the truth of the hypothesis. The three possible labels are true, false or neither."
...@@ -145,6 +146,7 @@ class Copa(HFTask): ...@@ -145,6 +146,7 @@ class Copa(HFTask):
return True return True
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out actual description
return "Given a premise and one alternative with a causal relation to the premise and another without," \ return "Given a premise and one alternative with a causal relation to the premise and another without," \
"choose the more plausible alternative" "choose the more plausible alternative"
...@@ -208,6 +210,7 @@ class MultiRC(HFTask): ...@@ -208,6 +210,7 @@ class MultiRC(HFTask):
return True return True
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out actual description
return "READING COMPREHENSION ANSWER KEY" return "READING COMPREHENSION ANSWER KEY"
def doc_to_text(self, doc): def doc_to_text(self, doc):
...@@ -260,24 +263,37 @@ class ReCoRD(HFTask): ...@@ -260,24 +263,37 @@ class ReCoRD(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_description(self):
# TODO: figure out actual description
return ""
def training_docs(self): def training_docs(self):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no. # Each doc consists of multiple answer candidates, each of which is scored yes/no.
# Hence, we use one "doc" for each (context + passage, answer) pair. # Hence, we use one "doc" for each (context + passage, answer) pair.
# Moreover, we only use the correct answers for context packing # Moreover, we only use the correct answers for context packing
# (This is not an issue for evaluation, where we can directly score multiple candidates at once). # (This is not an issue for evaluation, where we can directly score multiple candidates at once).
if self.has_training_docs(): if self._training_docs is None:
if self._training_docs is None: self._training_docs = []
self._training_docs = [] for doc in self.data["train"]:
for doc in self.data["train"]: for entity in list(set(doc["entities"])):
for entity in list(set(doc["entities"])): self._training_docs.append({
self._training_docs.append({ "passage": doc["passage"],
"passage": doc["passage"], "query": doc["query"],
"query": doc["query"], "entity": entity,
"entity": entity, "label": entity in doc["answers"],
"label": entity in doc["answers"], })
}) return self._training_docs
return self._training_docs
def validation_docs(self):
for doc in self.data["validation"]:
for entity in list(set(doc["entities"])):
yield {
"passage": doc["passage"],
"query": doc["query"],
"entity": entity,
"label": entity in doc["answers"],
}
def doc_to_text(self, doc): def doc_to_text(self, doc):
initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
...@@ -296,7 +312,7 @@ class ReCoRD(HFTask): ...@@ -296,7 +312,7 @@ class ReCoRD(HFTask):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
requests = [ requests = [
rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
for entity in doc["entities"] for entity in doc["entity"]
] ]
return requests return requests
...@@ -342,6 +358,10 @@ class WordsInContext(HFTask): ...@@ -342,6 +358,10 @@ class WordsInContext(HFTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_description(self):
# TODO: figure out actual description
return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Is the word '{}' used in the same way in the" \ return "{}\n{}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nanswer:".format( " two sentences above?\nanswer:".format(
...@@ -405,6 +425,7 @@ class SGWinogradSchemaChallenge(HFTask): ...@@ -405,6 +425,7 @@ class SGWinogradSchemaChallenge(HFTask):
return self._training_docs return self._training_docs
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out actual description
return "Final Exam with Answer Key\n" \ return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \ "Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \ "For each passage, you must identify which noun the pronoun marked in *bold*" \
......
import os
import json import json
import random import random
from lm_eval.base import Dataset from lm_eval.base import Dataset
from ..utils import sh from ..utils import sh
class TriviaQA(Dataset): class TriviaQA(Dataset):
def __init__(self):
self.download()
def download(self): def download(self):
#pass if not os.path.exists('data/triviaqa'):
#TODO: don't download if files already there sh("""
sh(""" mkdir -p data/triviaqa
mkdir -p data/triviaqa wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz mv triviaqa-unfiltered/ data/triviaqa/
mv triviaqa-unfiltered/ data/triviaqa/ """)
""")
def has_training_docs(self): def has_training_docs(self):
return True return True
......
...@@ -71,12 +71,14 @@ class WinogradSchemaChallenge273(Dataset): ...@@ -71,12 +71,14 @@ class WinogradSchemaChallenge273(Dataset):
docs.append(doc) docs.append(doc)
return docs return docs
def doc_to_text(self, doc, include_target=True):
# WSC273 is currently only writing out full examples. Partial evaluation needs implementing.
text = doc['completions']['T'] + ' True. ' + doc['completions']['F'] + ' False.'
return text
def doc_to_text(self, doc):
# TODO: implement
pass
def doc_to_target(self, doc):
# TODO: implement
pass
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment