Commit 1a159d6b authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness into winograd-fixes

parents 8b038c2a 7614a8f3
@@ -7,44 +7,45 @@ from tqdm import tqdm
 class GPT2LM(LM):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"

     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(device=args.get("device", "cpu"))
+        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

     def loglikelihood(self, requests):
         res = []
-        # TODO: vectorize properly
-        for context, continuation in tqdm(requests):
-            # when too long to fit in context, truncate from the left
-            if context == "":
-                # end of text as context
-                context_enc = [50256]
-            else:
-                context_enc = self.tokenizer.encode(context)
-
-            continuation_enc = self.tokenizer.encode(continuation)
-            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
-
-            greedy_tokens = logits.argmax(dim=-1)
-            max_equal = (greedy_tokens == cont_toks).all()
-
-            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
-
-            res.append((float(logits.sum()), bool(max_equal)))
+        with torch.no_grad():
+            # TODO: vectorize properly
+            for context, continuation in tqdm(requests):
+                # when too long to fit in context, truncate from the left
+                if context == "":
+                    # end of text as context
+                    context_enc = [50256]
+                else:
+                    context_enc = self.tokenizer.encode(context)
+
+                continuation_enc = self.tokenizer.encode(continuation)
+                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+
+                cont_toks = inp[:, ctxlen:]  # [batch, seq]
+                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+
+                greedy_tokens = logits.argmax(dim=-1)
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
+
+                res.append((float(logits.sum()), bool(max_equal)))

         return res
...
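The left-truncation arithmetic in loglikelihood is easy to misread, so here is a minimal standalone sketch, not part of the commit, checking that only the context (never the continuation) gets clipped to the 1024-token window; the token lists are made-up stand-ins for real tokenizer output.

# Illustrative sketch of the truncation logic above (hypothetical inputs).
context_enc = list(range(1000))      # stand-in context token ids
continuation_enc = list(range(50))   # stand-in continuation token ids

window = (context_enc + continuation_enc)[-1024:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

assert len(window) <= 1024
assert window[ctxlen:] == continuation_enc  # the continuation survives intact

When everything fits, the max(0, ...) term is zero and ctxlen is simply len(context_enc).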
@@ -20,6 +20,7 @@ from . import triviaqa
 from . import pubmedqa
 from . import sciq
 from . import webqs
+from . import qa4mre

 TASK_REGISTRY = {
@@ -48,8 +49,13 @@ TASK_REGISTRY = {
     "lambada": lambada.LAMBADA,
     "piqa": piqa.PiQA,
+    # Science related
     "pubmedqa" : pubmedqa.Pubmed_QA,
     "sciq" : sciq.SciQ,
+    #"qa4mre" : qa4mre.QA4MRE,
+    "qa4mre_2011" : qa4mre.QA4MRE_2011,
+    "qa4mre_2012" : qa4mre.QA4MRE_2012,
+    "qa4mre_2013" : qa4mre.QA4MRE_2013,
     #"triviaqa": triviaqa.TriviaQA,
     "arc_easy": arc.ARCEasy,
...
@@ -56,7 +56,10 @@ class Arithmetic(Task):
         return doc.completion

     def load_doc(self, doc_json):
-        return ArithmeticDoc(context=doc_json['context'].strip(), completion=doc_json['completion'].strip())
+        return ArithmeticDoc(context=doc_json['context'].strip()
+            .replace('\n\n', '\n')
+            .replace('Q:', 'Question:')
+            .replace('A:', 'Answer:'), completion=doc_json['completion'])

     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
...
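The effect of the new replace chain in load_doc, as a quick illustration; the sample string is invented, not taken from the dataset.

# Hypothetical raw context before/after the normalization above.
raw = "Q: What is 48 plus 76?\n\nA:"
print(raw.strip().replace('\n\n', '\n').replace('Q:', 'Question:').replace('A:', 'Answer:'))
# Question: What is 48 plus 76?
# Answer: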
@@ -28,8 +28,8 @@ class PiQA(HFTask):
         return " " + solutions[doc["label"]]

     def construct_requests(self, doc, ctx):
-        ll_1, _ = rf.loglikelihood(ctx, doc['sol1'])
-        ll_2, _ = rf.loglikelihood(ctx, doc['sol2'])
+        ll_1, _ = rf.loglikelihood(ctx, " " + doc['sol1'])
+        ll_2, _ = rf.loglikelihood(ctx, " " + doc['sol2'])
         return ll_1, ll_2

     def process_results(self, doc, results):
...
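The added leading space matters because GPT-2's byte-level BPE folds the preceding space into the token, so a continuation scored without it tokenizes differently than it would appear inline after the context. A minimal check, assuming transformers is installed:

from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
# "dog" and " dog" map to different BPE ids, so the loglikelihood of a
# continuation must include the space that would precede it in context.
assert tok.encode("dog") != tok.encode(" dog")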
import os
import numpy as np
from best_download import download_file
from lm_eval.base import MultipleChoiceTask, rf, mean
import xml.etree.ElementTree as ET
import random


class QA4MRE(MultipleChoiceTask):
    YEAR = None

    def download(self):
        year = self.YEAR
        lang = "EN"
        base_path = (
            "http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
            "file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
        )
        # TODO: add side tasks?
        variable_year_path = {
            2011: '2011/Training_Data/Goldstandard/',
            2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
            2013: '2013/Main_Task/Training_Data/Goldstandard/',
        }
        sha256sums = {
            2011: "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
            2012: "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
            2013: "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
        }
        vpath = variable_year_path[year]
        url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
        if not os.path.exists("data/qa4mre"):
            os.mkdir("data/qa4mre")
        if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml"):
            download_file(
                url_path,
                f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
                checksum=sha256sums[year],
            )

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def fewshot_examples(self, k):
        # This task has only test docs, so few-shot examples are sampled from them
        if self._training_docs is None:
            self._training_docs = list(self.test_docs())
        return random.sample(self._training_docs, k)

    def _convert_standard(self, question):
        choices = [i.text for i in question.iter('answer')]
        out_doc = {
            "query": question.find('q_str').text,
            "choices": choices,
            "gold": int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
        }
        return out_doc

    def load_docs(self, textfilename, tfds=False):
        tree = ET.parse(textfilename)
        root = tree.getroot()
        # TODO: the source passage is sometimes much larger than the context window;
        # at the moment it just gets left-truncated by the LM automatically,
        # and maybe that's good enough?
        for reading_test in root.iter('reading-test'):
            src = reading_test[0].text
            src = src.strip().replace("\\'", "'")
            for qid, question in enumerate(reading_test.iter('q')):
                out_doc = self._convert_standard(question)
                out_doc['source'] = src
                yield out_doc

    def fewshot_description(self):
        return ""

    def test_docs(self):
        return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])


class QA4MRE_2011(QA4MRE):
    YEAR = 2011


class QA4MRE_2012(QA4MRE):
    YEAR = 2012


class QA4MRE_2013(QA4MRE):
    YEAR = 2013
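To make the parsing above concrete, here is an illustrative fragment shaped the way load_docs and _convert_standard expect. The element layout is inferred from the parsing calls, not from the official QA4MRE schema, so treat it as a sketch.

import xml.etree.ElementTree as ET

# Hypothetical document mimicking the structure the code reads:
# <reading-test> with a passage as its first child, then <q> blocks.
sample = """
<topic>
  <reading-test>
    <doc>The quick brown fox jumps over the lazy dog.</doc>
    <q>
      <q_str>What does the fox jump over?</q_str>
      <answer a_id="1" correct="Yes">the lazy dog</answer>
      <answer a_id="2">the fence</answer>
    </q>
  </reading-test>
</topic>
"""

root = ET.fromstring(sample)
reading_test = next(root.iter('reading-test'))
src = reading_test[0].text.strip()  # first child holds the passage
question = next(reading_test.iter('q'))
gold = int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1
print(src)
print(question.find('q_str').text, "->", gold)  # gold is zero-indexed, so 0 here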
@@ -82,9 +82,12 @@ class RACE(HFTask):
     def doc_to_text(self, doc):
         text = 'Article: ' + doc['article'] + '\n\n'
         for problem in doc['problems'][:-1]:
-            assert problem['question'][-6:] == ' _ .'
-
-            text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+            if problem['question'][-6:] == ' _ .':
+                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+            else:
+                question = 'Question: ' + problem['question'] + '\n'
+                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+                text += question + answer
         text += self.last_problem(doc)['question']
         return text
...
@@ -3,6 +3,7 @@ import json
 from ..utils import sh
 from lm_eval.base import MultipleChoiceTask, rf, mean
 import zipfile
+from best_download import download_file

 class SciQ(MultipleChoiceTask):
@@ -10,9 +11,11 @@ class SciQ(MultipleChoiceTask):
     def download(self):
         if not os.path.exists('data/sciq'):
             os.mkdir('data/sciq')
-            sh((
-                "wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
-            ))
+            download_file(
+                'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
+                'data/sciq/SciQ.zip',
+                '7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
+            )
             with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
                 zf.extractall("data/sciq/")
@@ -48,8 +51,6 @@ class SciQ(MultipleChoiceTask):
         yield self._convert_standard(record)

     def fewshot_description(self):
-        # Average ctx length in labelled dataset is 238.9
-        # 2 few-shot exmamples pushes it beyond context window
         return ""

     def training_docs(self):
...
@@ -218,7 +218,7 @@ class MultiRC(HFTask):
         return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"

     def doc_to_target(self, doc):
-        return self.format_answer(answer=doc["answer"], label=doc["label"])
+        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])

     @staticmethod
     def format_answer(answer, label):
...
@@ -20,7 +20,7 @@ def parse_args():
     parser.add_argument('--seed', type=int, default=1234)
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--cache', action="store_true")
+    parser.add_argument('--no_cache', action="store_true")
     return parser.parse_args()

 def main():
@@ -31,7 +31,7 @@ def main():
     lm = models.get_model(args.model).create_from_arg_string(args.model_args)
-    if args.cache:
+    if not args.no_cache:
         lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
...
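Note that this flips caching from opt-in (--cache) to opt-out (--no_cache), and the cache filename is derived from the model arguments. A quick illustration of the path construction, mirroring the string code in main() with hypothetical argument values:

# Mirrors the CachingLM path built above; values are hypothetical.
model, model_args = 'gpt2', 'device=cuda,pretrained=gpt2-medium'
path = 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_') + '.db'
print(path)  # lm_cache/gpt2_device-cuda_pretrained-gpt2-medium.db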