Commit 8161c22e authored by Jason Phang

SuperGLUE, and truncation

parent e7a87e71
@@ -31,6 +31,16 @@ class LM(abc.ABC):
        """
        pass

+    @classmethod
+    def num_tokens(cls, string):
+        """Return the number of tokens in a string, based on tokenization
+
+        :param string: str
+            Input string
+        :return: int
+        """
+        pass
+
    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
@@ -106,7 +116,7 @@ class Dataset(abc.ABC):
    def fewshot_context(self, doc, num_fewshot, provide_description):
        raw_description = self.fewshot_description()
-        description = (raw_description + "\n\n") if provide_description and raw_description else ""
+        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
        labeled_examples = "\n\n".join(
            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
        ) + "\n\n"
...
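For orientation, with provide_description=True the context assembled here now starts with the task description, an "===" divider line, then the few-shot examples separated by blank lines. Assuming CoLA's description and two made-up few-shot examples, the prefix looks roughly like:

    Does this sentence make sense?:	True or False?
    ===

    Sentence: The boy quickly ran.
    Answer: True

    Sentence: Mat the on sat cat the.
    Answer: False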
@@ -18,7 +18,7 @@ class GPT2LM(LM):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"))

-    def generate(self, context, max_gen_length):
+    def generate(self, context, max_gen_length, truncate=True):
        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
        res = self.gpt2.generate(
            context,
@@ -30,11 +30,14 @@ class GPT2LM(LM):
        # chop off the prompt and the final eos token
        return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()

-    def loglikelihood(self, context, continuation):
+    def loglikelihood(self, context, continuation, truncate=True):
        inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
        ctxlen = len(self.tokenizer.encode(context.strip()))

        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
+
+    def num_tokens(self, string):
+        return len(self.tokenizer.tokenize(string))
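A note on the index arithmetic in loglikelihood: the logits at position i are the model's prediction for token i + 1, so slicing at ctxlen - 1:-1 lines the scores up with the continuation tokens inp[:, ctxlen:]. A minimal, self-contained sketch of the same gather on toy tensors (not the repo's API):

    import torch
    import torch.nn.functional as F

    vocab, ctxlen = 5, 3
    inp = torch.tensor([[0, 1, 2, 3, 4]])             # 3 context tokens + 2 continuation tokens
    toy_scores = torch.randn(1, inp.shape[1], vocab)  # stand-in for self.gpt2(inp)[0]
    logprobs = F.log_softmax(toy_scores, dim=-1)[:, ctxlen - 1:-1]  # predictions *for* tokens 3 and 4
    cont_toks = inp[:, ctxlen:]                       # [batch, 2]
    per_token = torch.gather(logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, 2] log-probs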
import os
-import openai
import transformers
from lm_eval.base import LM
from lm_eval import utils
@@ -8,9 +7,22 @@ from . import MODEL_REGISTRY

@MODEL_REGISTRY.register("gpt3")
class GPT3LM(LM):
-    def __init__(self, engine):
+
+    MAX_LENGTH = 2048
+
+    def __init__(self, engine, truncate=False):
+        """
+        :param engine: str
+            OpenAI API engine (e.g. davinci)
+        :param truncate: bool
+            Truncate input if too long (if False and input is too long, throw error)
+        """
+        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.truncate = truncate

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@@ -20,23 +32,34 @@ class GPT3LM(LM):
        return cls(engine=args.get("engine", "davinci"))

    def generate(self, context, max_gen_length):
+        import openai
+        if self.truncate:
+            prompt = self.smart_truncate(context, buffer=max_gen_length)
+        else:
+            prompt = context
        response = openai.Completion.create(
            engine=self.engine,
-            prompt=context,
+            prompt=prompt,
            max_tokens=max_gen_length,
            temperature=0.0,
        )
        return response.choices[0]["text"]

    def loglikelihood(self, context, continuation):
+        import openai
        full_text = context + continuation
        full_text_length = len(self.tokenizer.tokenize(full_text))
        context_length = len(self.tokenizer.tokenize(context))
        continuation_length = len(self.tokenizer.tokenize(continuation))
        assert full_text_length == context_length + continuation_length
+        if self.truncate:
+            prompt = self.smart_truncate(full_text, buffer=0)
+        else:
+            prompt = full_text
        response = openai.Completion.create(
            engine=self.engine,
-            prompt=full_text,
+            prompt=prompt,
            echo=True,
            max_tokens=0, temperature=0.0,
            logprobs=0,
@@ -44,3 +67,13 @@ class GPT3LM(LM):
        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        continuation_logprobs = logprobs[-continuation_length:]
        return sum(continuation_logprobs)
+
+    def smart_truncate(self, string, buffer=1):
+        tokens = self.tokenizer.tokenize(string)
+        available_length = self.MAX_LENGTH - 1 - buffer  # OpenAI adds 1 token
+        kept_tokens = tokens[-available_length:]
+        new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
+        return new_string
+
+    def num_tokens(self, string):
+        return len(self.tokenizer.tokenize(string))
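As a quick check on the truncation arithmetic: smart_truncate keeps only the last MAX_LENGTH - 1 - buffer tokens (the extra 1 covering the token the API adds on its side), so overlong prompts lose text from the front and the most recent few-shot context survives. A tiny sketch with a hypothetical helper, not part of the diff:

    def truncated_length(num_tokens, max_length=2048, buffer=0):
        # mirrors smart_truncate: keep at most max_length - 1 - buffer trailing tokens
        return min(num_tokens, max_length - 1 - buffer)

    assert truncated_length(3000, buffer=32) == 2015  # long prompt, room left for 32 generated tokens
    assert truncated_length(500, buffer=32) == 500    # short prompts pass through unchanged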
@@ -8,12 +8,20 @@ class NLP_TASK(Dataset):
    NLP_PATH = None
    NLP_NAME = None

+    def __init__(self):
+        super().__init__()
+        self._training_docs = None
+
    def _load_nlp_dataset(self):
        return nlp.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME)

    def training_docs(self):
+        # Cache training docs for faster few-shot sampling.
+        # If the data is too large to fit in memory, override this method.
        if self.has_training_docs():
-            return self._load_nlp_dataset()["train"]
+            if self._training_docs is None:
+                self._training_docs = list(self._load_nlp_dataset()["train"])
+            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
...
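The comment above anticipates datasets too large to cache as a list; a rough sketch of the kind of override it suggests (the class name and NLP_PATH are placeholders, not part of the diff):

    class StreamingTask(NLP_TASK):
        NLP_PATH = "some_large_dataset"  # placeholder

        def training_docs(self):
            # Skip the list() cache and hand the nlp dataset back lazily each time.
            if self.has_training_docs():
                return self._load_nlp_dataset()["train"]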
@@ -11,7 +11,10 @@ class CoQA(Dataset):
    def has_validation_docs(self):
        return False

+    def has_test_docs(self):
+        return False
+
    def training_docs(self):
        myjson = json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
        return self.load_doc(myjson)
...
@@ -41,7 +41,7 @@ class CoLA(NLP_TASK):
        return "Does this sentence make sense?:\tTrue or False?"

    def doc_to_text(self, doc, include_target=True):
-        text = "\nSentence:{}\nAnswer: ".format(doc["sentence"])
+        text = "Sentence: {}\nAnswer:".format(doc["sentence"])
        if include_target:
            text += " {}".format({1: "True", 0: "False"}[doc["label"]])
        return text
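With the new template, a CoLA doc such as {"sentence": "The boy quickly ran.", "label": 1} (example values made up) renders as:

    Sentence: The boy quickly ran.
    Answer: True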
@@ -153,7 +153,7 @@ class MRPC(NLP_TASK):
                provide_description=provide_description,
                num_fewshot=num_fewshot,
            )
-            preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
+            preds.append(lm.loglikelihood(ctx, 'yes') > lm.loglikelihood(ctx, 'no'))
        return get_accuracy_and_f1(preds=preds, golds=golds)
@@ -210,14 +210,14 @@ class QNLI(NLP_TASK):
        return True

    def doc_to_text(self, doc, include_target=True):
-        text = "{}\nquestion:\t{}\tTrue or False?\nanswer:".format(
+        text = "question:\t{}\nresponse:\t{}\nDoes this answer the question, Yes or No?:".format(
            doc["question"],
            doc["sentence"],
        )
        if include_target:
            # True = entailment
            # False = not entailment
-            text += " {}".format({0: "True", 1: "False"}[doc["label"]])
+            text += " {}".format({0: "Yes", 1: "No"}[doc["label"]])
        return text

    def evaluate(self, docs, lm, provide_description, num_fewshot):
@@ -248,7 +248,7 @@ class QQP(NLP_TASK):
        return True

    def fewshot_description(self):
-        return "Indicate if both sentences mean the same thing."
+        return "Indicate if both questions ask the same thing."

    def doc_to_text(self, doc, include_target=True):
        text = "question 1:\t{}\nquestion 2:\t{}\nanswer:".format(
@@ -296,7 +296,7 @@ class STSB(NLP_TASK):
            doc["sentence2"],
        )
        if include_target:
-            text += " {}".format(yesno(doc["label"]))
+            text += " {}".format(doc["label"])
        return text

    def evaluate(self, docs, lm, provide_description, num_fewshot):
@@ -314,6 +314,7 @@ class STSB(NLP_TASK):
                pred = max(min(float(first_element), 5.0), 0.0)
            else:
                pred = 2.5
+            import pdb; pdb.set_trace()
            preds.append(pred)
        pearson_corr = float(pearsonr(preds, golds)[0])
        spearman_corr = float(spearmanr(preds, golds)[0])
@@ -383,8 +384,8 @@ class WNLI(NLP_TASK):
    def doc_to_text(self, doc, include_target=True):
        text = "{}\nquestion:\t{}\tTrue, False or Neither?\nanswer:".format(
-            doc["premise"],
-            doc["hypothesis"],
+            doc["sentence1"],
+            doc["sentence2"],
        )
        if include_target:
            # True = entailment
...
+import numpy as np
+from tqdm import auto as tqdm_lib
+
from . common import NLP_TASK, simple_accuracy_metric, yesno
from . import TASK_REGISTRY
@@ -34,3 +36,195 @@ class BoolQ(NLP_TASK):
            )
            preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
        return simple_accuracy_metric(preds=preds, golds=golds)
+
+
+@TASK_REGISTRY.register("cb")
+class CommitmentBank(NLP_TASK):
+    NLP_PATH = "super_glue"
+    NLP_NAME = "cb"
+
+    def has_training_docs(self):
+        return True
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def doc_to_text(self, doc, include_target=True):
+        text = "{}\nquestion:\t{}\ttrue, false or neither?\nanswer:".format(
+            doc["premise"],
+            doc["hypothesis"],
+        )
+        if include_target:
+            # True = entailment
+            # False = contradiction
+            # Neither = neutral
+            text += " {}".format({0: "true", 1: "neither", 2: "false"}[doc["label"]])
+        return text
+
+    def evaluate(self, docs, lm, provide_description, num_fewshot):
+        golds = [doc["label"] for doc in docs]
+        preds = []
+        for doc in tqdm_lib.tqdm(docs):
+            ctx = self.fewshot_context(
+                doc=doc,
+                provide_description=provide_description,
+                num_fewshot=num_fewshot,
+            )
+            probs = np.array([
+                lm.loglikelihood(ctx, ' true'),
+                lm.loglikelihood(ctx, ' neither'),
+                lm.loglikelihood(ctx, ' false'),
+            ])
+            preds.append(np.argmax(probs))
+        return simple_accuracy_metric(preds=preds, golds=golds)
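The argmax index here doubles as the predicted label, so the order of the scored continuations has to mirror the {0: "true", 1: "neither", 2: "false"} mapping used in doc_to_text. An illustrative check with made-up log-likelihood values:

    import numpy as np

    probs = np.array([-4.2, -1.3, -6.0])  # hypothetical scores for ' true', ' neither', ' false'
    assert np.argmax(probs) == 1          # prediction is label 1, i.e. "neither"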
@TASK_REGISTRY.register("copa")
class Copa(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "copa"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def doc_to_text(self, doc, include_target=True):
# Drop the period
text = doc["premise"].strip()[:-1] + " because "
if include_target:
correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
# Connect the sentences
text += self.convert_choice(correct_choice)
return text
def evaluate(self, docs, lm, provide_description, num_fewshot):
golds = [doc["label"] for doc in docs]
preds = []
for doc in tqdm_lib.tqdm(docs):
ctx = self.fewshot_context(
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
)
choice1 = " " + self.convert_choice(doc["choice1"])
choice2 = " " + self.convert_choice(doc["choice2"])
preds.append(lm.loglikelihood(ctx, choice2) > lm.loglikelihood(ctx, choice1))
return simple_accuracy_metric(preds=preds, golds=golds)
@staticmethod
def convert_choice(choice):
return choice[0].lower() + choice[1:]
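A worked example of the prompt construction, using a made-up COPA-style doc: with premise "The man broke his toe.", choice1 "He got a hole in his sock.", choice2 "He dropped a hammer on his foot." and label 1, doc_to_text yields

    The man broke his toe because he dropped a hammer on his foot.

and evaluate compares lm.loglikelihood(ctx, ' he dropped a hammer on his foot.') against lm.loglikelihood(ctx, ' he got a hole in his sock.').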
@TASK_REGISTRY.register("wic")
class WordsInContext(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wic"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def doc_to_text(self, doc, include_target=True):
text = "{}\n{}\nquestion\tIs the word '{}' used in the same way in the" \
" two sentences above?\nanswer:".format(
doc["sentence1"],
doc["sentence2"],
doc["sentence1"][doc["start1"]:doc["end1"]],
)
if include_target:
text += " {}".format({0: "no", 1: "yes"}[doc["label"]])
return text
def evaluate(self, docs, lm, provide_description, num_fewshot):
golds = [doc["label"] for doc in docs]
preds = []
for doc in tqdm_lib.tqdm(docs):
ctx = self.fewshot_context(
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
)
preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
return simple_accuracy_metric(preds=preds, golds=golds)
@TASK_REGISTRY.register("wsc")
class WinogradSchemaChallenge(NLP_TASK):
NLP_PATH = "super_glue"
NLP_NAME = "wsc"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
# GPT-3 Paper's format only uses positive examples
self._training_docs = [
doc for doc in
self._load_nlp_dataset()["train"]
if doc["label"]
]
return self._training_docs
def fewshot_description(self):
return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \
" refers to.\n====="
def doc_to_text(self, doc, include_target=True):
raw_passage = doc["text"]
passage = (
raw_passage[:doc["span2_index"]]
+ "*{}*".format(doc["span2_text"])
+ raw_passage[doc["span2_index"] + len(doc["span2_text"]):]
)
pronoun = doc["span2_text"]
text = (
f"Passage: {passage}\n"
+ f"Question: In the passage above, what does the pronoun \"*{pronoun}*\" refer to?\n"
+ "Answer:"
)
if include_target:
text += " {}".format(doc["span1_text"])
return text
def evaluate(self, docs, lm, provide_description, num_fewshot):
golds = [doc["label"] for doc in docs]
preds = []
for doc in tqdm_lib.tqdm(docs):
ctx = self.fewshot_context(
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
)
to_predict = " " + doc["span1_text"]
num_tokens = len(lm.tokenizer.tokenize(to_predict))
generated = lm.generate(
context=ctx,
max_gen_length=num_tokens,
)
preds.append(1 if generated == to_predict else 0)
return simple_accuracy_metric(preds=preds, golds=golds)
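For reference, the WSC passage markup splices asterisks around the pronoun at span2_index. With an illustrative doc (all values made up) where text = "The trophy doesn't fit in the suitcase because it is too big.", span2_index = 47, span2_text = "it" and span1_text = "the trophy", the rendered passage becomes

    The trophy doesn't fit in the suitcase because *it* is too big.

and evaluate then checks whether greedy generation from the prompt reproduces " the trophy" exactly.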
@@ -14,6 +14,7 @@ def parse_args():
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
+    parser.add_argument('--output_path', default=None)
    return parser.parse_args()
@@ -42,7 +43,11 @@ def main():
            num_fewshot=args.num_fewshot,
        )
        results[task_name] = result
-    print(json.dumps(results, indent=2))
+    dumped = json.dumps(results, indent=2)
+    print(dumped)
+    if args.output_path:
+        with open(args.output_path, "w") as f:
+            f.write(dumped)

if __name__ == "__main__":
...
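With the new flag, an invocation along the lines of python main.py --num_fewshot 1 --output_path results.json (the remaining flags are whatever the existing parser already defines) still prints the per-task JSON and additionally writes the same string to results.json.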
import argparse
import numpy as np
import os
import random
from lm_eval import tasks

EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_base_path', required=True)
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--num_examples', type=int, default=1)
    return parser.parse_args()


def main():
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = {
        task_name: tasks.get_task(task_name)()
        for task_name in task_names
    }

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        if not task.has_validation_docs():
            continue
        docs = task.validation_docs()
        with open(os.path.join(args.output_base_path, task_name), "w") as f:
            for i, doc in zip(range(args.num_examples), docs):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    provide_description=args.provide_description,
                    num_fewshot=args.num_fewshot,
                )
                f.write(ctx + "\n")


if __name__ == "__main__":
    main()
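This script is a prompt-inspection helper: something like python write_out.py --output_base_path outputs --tasks cb,copa --num_fewshot 3 --num_examples 5 (the path and task names here are just examples) writes one file per task under outputs/, each holding five rendered few-shot contexts separated by the EXAMPLE_DIVIDER marker, so the prompts can be eyeballed before spending model calls.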