Commit 2d4b3a8c authored by Jason Phang

checkin

parent 12e12bc0
@@ -18,7 +18,7 @@ class LM(abc.ABC):
    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute log-prob of generating a continuation from a context
        """Compute log-likelihood of generating a continuation from a context

        Assume that the final text will simply be
        context + continuation
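To make this contract concrete, here is a brief editorial illustration (not part of the commit) of how a caller splits a prompt into the two arguments; the strings below are made up:

# Illustration only: the caller decides where context ends and continuation begins.
# loglikelihood scores the continuation conditioned on the context, with the model
# fed the concatenation context + continuation.
context = "The Eiffel Tower is located in"
continuation = " Paris"
# score = lm.loglikelihood(context, continuation)   # lm: any concrete LM subclass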
@@ -46,14 +46,26 @@ class Dataset(abc.ABC):
class Dataset(abc.ABC):
    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    @abc.abstractmethod
    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        pass

    @abc.abstractmethod
@@ -70,10 +82,6 @@ class Dataset(abc.ABC):
        random.shuffle(traindocs)
        return traindocs[:k]

    @abc.abstractmethod
    def fewshot_description(self):
        pass

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
@@ -81,8 +89,29 @@ class Dataset(abc.ABC):
    @abc.abstractmethod
    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Takes an iterable of docs, evaluates them, and returns a dict in the following format:
        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher value of the metric is better
        """
        pass

    def fewshot_prefix(self):
        return ""

    def fewshot_context(self, doc, k):
        prefix = self.fewshot_prefix()
        labeled_examples = "\n\n".join(self.doc_to_text(d) for d in self.fewshot_examples(k))
        example = self.doc_to_text(doc, include_target=False)
        if labeled_examples:
            # keep a blank line between the last labeled example and the query
            labeled_examples += "\n\n"
        return prefix + labeled_examples + example
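As an editorial sketch (not part of the commit), this is the kind of string fewshot_context is meant to assemble, reproduced with plain strings; the task prompt and the documents are invented:

# Hypothetical k=2 prompt: prefix, then labeled examples separated by blank lines,
# then the query document rendered with include_target=False.
prefix = ""                                    # fewshot_prefix()
labeled = "\n\n".join([
    "sentence: Colorless green ideas sleep furiously.\nAnswer: False",
    "sentence: The cat sat on the mat.\nAnswer: True",
])
query = "sentence: The dog chased the ball.\nAnswer:"
prompt = prefix + labeled + "\n\n" + query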
class Registry:
    def __init__(self, registry_name):
......
# Tasks backed by the nlp library generally do not require separately downloading data
# coqa
mkdir -p data/coqa
......
import transformers
import torch

from ..base import LM
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("dummy")
class DummyLM(LM):
    def generate(self, context, max_gen_length):
        return "lol"

    def loglikelihood(self, context, continuation):
        return 0.0
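A quick editorial illustration (not part of the commit) of the stub's behaviour:

# DummyLM ignores its inputs and returns fixed values, which makes it useful for
# exercising task plumbing without loading a real model.
lm = DummyLM()
print(lm.generate(context="Q: 2 + 2 =\nA:", max_gen_length=1))        # "lol"
print(lm.loglikelihood(context="Q: 2 + 2 =\nA:", continuation=" 4"))  # 0.0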
@@ -22,5 +22,5 @@ class GPT2LM(LM):
        # chop off the prompt and the final eos token
        return self.tok.decode(res[0][len(context[0]):-1]).strip()

    def nll_of(self, context, continuation):
    def loglikelihood(self, context, continuation):
        pass
@@ -28,7 +28,7 @@ class GPT3LM(LM):
        )
        return response.choices[0]["text"]

    def logprob_of(self, context, continuation):
    def loglikelihood(self, context, continuation):
        full_text = context + continuation
        full_text_length = len(self.tokenizer.tokenize(full_text))
        context_length = len(self.tokenizer.tokenize(context))
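The hunk is truncated here; the token counts above locate the continuation's tokens at the end of the full sequence. One hedged way the method could be finished, assuming the then-current OpenAI completions API (echo=True with max_tokens=0 returns per-token logprobs for the prompt itself) rather than whatever the commit actually does:

        # Editorial sketch, not part of the commit. `self.engine` is an assumed attribute.
        response = openai.Completion.create(
            engine=self.engine,
            prompt=full_text,
            max_tokens=0,
            echo=True,
            logprobs=0,
        )
        token_logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        # sum only the logprobs of the continuation tokens, i.e. everything after the context
        return sum(token_logprobs[context_length:full_text_length])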
......
import abc

import nlp
import numpy as np

from ..base import Dataset


class NLP_TASK(Dataset):
    NLP_PATH = None
    NLP_NAME = None

    def _load_nlp_dataset(self):
        return nlp.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME)

    def training_docs(self):
        if self.has_training_docs():
            return self._load_nlp_dataset()["train"]

    def validation_docs(self):
        if self.has_validation_docs():
            return self._load_nlp_dataset()["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self._load_nlp_dataset()["test"]
def simple_accuracy_metric(preds, golds):
    acc = float((np.array(preds) == np.array(golds)).mean())
    return {
        "major": acc,
        "minor": {"acc": acc},
        "higher_is_better": True,
    }
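For reference, an editorial example (not part of the commit) of the dict this helper produces, matching the evaluate() contract in base.py:

simple_accuracy_metric(preds=[1, 0, 1, -1], golds=[1, 0, 0, 1])
# -> {"major": 0.5, "minor": {"acc": 0.5}, "higher_is_better": True}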
import nlp
import numpy as np
import random
from sklearn.metrics import f1_score, matthews_corrcoef

from .common import NLP_TASK, simple_accuracy_metric
from . import TASK_REGISTRY


@TASK_REGISTRY.register("cola")
class CoLA(NLP_TASK):
    # GLUE / CoLA identifiers in the nlp library
    NLP_PATH = "glue"
    NLP_NAME = "cola"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def doc_to_text(self, doc, include_target=True):
        text = "Does this sentence make sense?:\tTrue or False?" \
               "\nsentence:{}\nAnswer: ".format(doc["sentence"])
        if include_target:
            text += " {}".format({1: "True", 0: "False"}[doc["label"]])
        return text

    def evaluate(self, docs, lm, k=0):
        golds = [doc["label"] for doc in docs]
        preds = []
        for doc in docs:
            word = lm.generate(
                context=self.fewshot_context(doc=doc, k=k),
                max_gen_length=1,
            )
            if word.strip() == "True":
                preds.append(1)
            elif word.strip() == "False":
                preds.append(0)
            else:
                preds.append(-1)
        golds = np.array(golds)
        preds = np.array(preds)
        mcc = float(matthews_corrcoef(y_true=golds, y_pred=preds))
        return {
            "major": mcc,
            "minor": {"mcc": mcc},
            "higher_is_better": True,
        }
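An editorial illustration (not part of the commit) of the prompt doc_to_text builds for a made-up CoLA document; note the tab after the question and the double space before the target, both exactly as written above:

doc = {"sentence": "The boy quickly ran the race.", "label": 1}
CoLA().doc_to_text(doc)
# -> "Does this sentence make sense?:\tTrue or False?\nsentence:The boy quickly ran the race.\nAnswer:  True"
CoLA().doc_to_text(doc, include_target=False)
# -> "Does this sentence make sense?:\tTrue or False?\nsentence:The boy quickly ran the race.\nAnswer: "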
@TASK_REGISTRY.register("mnli")
class MNLI(NLP_TASK):
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def validation_docs(self):
if self.has_validation_docs():
return self._load_nlp_dataset()["validation_matched"]
def test_docs(self):
if self.has_test_docs():
return self._load_nlp_dataset()["test_matched"]
def doc_to_text(self, doc, include_target=True):
text = "{}\nquestion:\t{}\tTrue, False or Neither?\nanswer:".format(
doc["sentence1"],
doc["sentence2"],
)
if include_target:
# True = entailment
# False = contradiction
# Neither = neutral
text += " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
return text
def evaluate(self, docs, lm, k=0):
golds = [doc["label"] for doc in docs]
preds = []
for doc in docs:
word = lm.generate(
context=self.fewshot_context(doc=doc, k=k),
max_gen_length=1,
)
if word.strip() == "True":
preds.append(1)
elif word.strip() == "False":
preds.append(0)
else:
preds.append(-1)
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("rte")
class RTE(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "rte"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def doc_to_text(self, doc, include_target=True):
text = "{}\nquestion:\t{}\tTrue or False?\nanswer:".format(
doc["sentence1"],
doc["sentence2"],
)
if include_target:
text += " {}".format({1: "True", 0: "False"}[doc["label"]])
return text
def evaluate(self, docs, lm, k=0):
golds = [doc["label"] for doc in docs]
preds = []
for doc in docs:
word = lm.generate(
context=self.fewshot_context(doc=doc, k=k),
max_gen_length=1,
)
if word.strip() == "True":
preds.append(1)
elif word.strip() == "False":
preds.append(0)
else:
preds.append(-1)
return simple_accuracy_metric(preds=preds, golds=golds)
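Finally, an editorial sketch (not part of the commit) of how a task and a model plug together; the import paths are assumptions about the package layout:

from lm_eval.models.dummy import DummyLM   # hypothetical path
from lm_eval.tasks.glue import RTE          # hypothetical path

task = RTE()
docs = list(task.validation_docs())
result = task.evaluate(docs=docs, lm=DummyLM(), k=0)
# DummyLM always generates "lol", so every prediction falls through to -1 and the
# accuracy is 0.0; with a real LM the returned dict follows the evaluate() contract:
# {"major": <acc>, "minor": {"acc": <acc>}, "higher_is_better": True}
print(result)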