"include/ck/utility/get_shift.hpp" did not exist on "f4dfc060b79987580da9afc481dad746d5b3178d"
Commit 1a159d6b authored by Leo Gao's avatar Leo Gao
Browse files

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness into winograd-fixes

parents 8b038c2a 7614a8f3
@@ -7,44 +7,45 @@ from tqdm import tqdm


 class GPT2LM(LM):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"

     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(device=args.get("device", "cpu"))
+        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

     def loglikelihood(self, requests):
         res = []
-        # TODO: vectorize properly
-        for context, continuation in tqdm(requests):
-            # when too long to fit in context, truncate from the left
-            if context == "":
-                # end of text as context
-                context_enc = [50256]
-            else:
-                context_enc = self.tokenizer.encode(context)
-            continuation_enc = self.tokenizer.encode(continuation)
-            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
-            greedy_tokens = logits.argmax(dim=-1)
-            max_equal = (greedy_tokens == cont_toks).all()
-            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
-            res.append((float(logits.sum()), bool(max_equal)))
+        with torch.no_grad():
+            # TODO: vectorize properly
+            for context, continuation in tqdm(requests):
+                # when too long to fit in context, truncate from the left
+                if context == "":
+                    # end of text as context
+                    context_enc = [50256]
+                else:
+                    context_enc = self.tokenizer.encode(context)
+                continuation_enc = self.tokenizer.encode(continuation)
+                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+                cont_toks = inp[:, ctxlen:]  # [batch, seq]
+                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+                greedy_tokens = logits.argmax(dim=-1)
+                max_equal = (greedy_tokens == cont_toks).all()
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
+                res.append((float(logits.sum()), bool(max_equal)))
         return res
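Note: the truncation arithmetic above is easy to misread. A standalone sketch with dummy token ids and made-up lengths (nothing here is from the diff) shows why the continuation always survives the left-truncation to the 1024-token window:

# Hypothetical lengths, chosen only for illustration.
context_enc = list(range(1000))        # 1000 dummy context token ids
continuation_enc = list(range(50))     # 50 dummy continuation token ids

inp = (context_enc + continuation_enc)[-1024:]                      # keep the last 1024 tokens
overflow = max(0, len(context_enc) + len(continuation_enc) - 1024)  # 26 tokens dropped from the left
ctxlen = len(context_enc) - overflow                                # 974: continuation starts here
assert inp[ctxlen:] == continuation_enc                             # the continuation is never cut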
@@ -20,6 +20,7 @@ from . import triviaqa
 from . import pubmedqa
 from . import sciq
 from . import webqs
+from . import qa4mre


 TASK_REGISTRY = {
@@ -48,8 +49,13 @@ TASK_REGISTRY = {
     "lambada": lambada.LAMBADA,
     "piqa": piqa.PiQA,

     # Science related
     "pubmedqa" : pubmedqa.Pubmed_QA,
     "sciq" : sciq.SciQ,
+    #"qa4mre" : qa4mre.QA4MRE,
+    "qa4mre_2011" : qa4mre.QA4MRE_2011,
+    "qa4mre_2012" : qa4mre.QA4MRE_2012,
+    "qa4mre_2013" : qa4mre.QA4MRE_2013,
+
     #"triviaqa": triviaqa.TriviaQA,
     "arc_easy": arc.ARCEasy,
@@ -56,7 +56,10 @@ class Arithmetic(Task):
         return doc.completion

     def load_doc(self, doc_json):
-        return ArithmeticDoc(context=doc_json['context'].strip(), completion=doc_json['completion'].strip())
+        return ArithmeticDoc(context=doc_json['context'].strip()
+            .replace('\n\n', '\n')
+            .replace('Q:', 'Question:')
+            .replace('A:', 'Answer:'), completion=doc_json['completion'])

     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
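For reference, the effect of the new replace chain on a raw arithmetic prompt (made-up example string, not from the dataset):

raw = "Q: What is 2 plus 2?\n\nA:"
print(raw.replace('\n\n', '\n').replace('Q:', 'Question:').replace('A:', 'Answer:'))
# Question: What is 2 plus 2?
# Answer: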
@@ -28,8 +28,8 @@ class PiQA(HFTask):
         return " " + solutions[doc["label"]]

     def construct_requests(self, doc, ctx):
-        ll_1, _ = rf.loglikelihood(ctx, doc['sol1'])
-        ll_2, _ = rf.loglikelihood(ctx, doc['sol2'])
+        ll_1, _ = rf.loglikelihood(ctx, " " + doc['sol1'])
+        ll_2, _ = rf.loglikelihood(ctx, " " + doc['sol2'])
         return ll_1, ll_2

     def process_results(self, doc, results):
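The prepended space matters because GPT-2's BPE vocabulary distinguishes word-initial pieces from mid-sentence pieces: " dog" and "dog" encode to different token ids, and only the former matches how a solution actually appears after a context. A quick check using the standard transformers API (the MultiRC doc_to_target change further down prepends a space for the same reason):

import transformers

tok = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
# Different ids: without the leading space, the continuation is tokenized
# as if it began a document, which is not how it follows ctx at eval time.
print(tok.encode("dog"), tok.encode(" dog"))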
import os
import numpy as np
from best_download import download_file
from lm_eval.base import MultipleChoiceTask, rf, mean
import xml.etree.ElementTree as ET
import random


class QA4MRE(MultipleChoiceTask):
    YEAR = None

    def download(self):
        year = self.YEAR
        lang = "EN"
        base_path = (
            "http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
            "file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
        )
        # TODO: add side tasks?
        variable_year_path = {
            2011: '2011/Training_Data/Goldstandard/',
            2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
            2013: '2013/Main_Task/Training_Data/Goldstandard/',
        }
        sha256sums = {
            2011: "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
            2012: "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
            2013: "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
        }
        vpath = variable_year_path[year]
        url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
        if not os.path.exists("data/qa4mre"):
            os.mkdir("data/qa4mre")
        # check for the actual downloaded file name so the cached copy is reused
        if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml"):
            download_file(
                url_path,
                f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
                checksum=sha256sums[year],
            )
    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True
    def fewshot_examples(self, k):
        # There are only test docs, so few-shot examples are sampled from the test set too
        if self._training_docs is None:
            self._training_docs = list(self.test_docs())
        return random.sample(self._training_docs, k)
    def _convert_standard(self, question):
        choices = [i.text for i in question.iter('answer')]
        out_doc = {
            "query": question.find('q_str').text,
            "choices": choices,
            "gold": int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
        }
        return out_doc

    def load_docs(self, textfilename, tfds=False):
        tree = ET.parse(textfilename)
        root = tree.getroot()
        # TODO: the source passage is sometimes much larger than the model's context window;
        # at the moment it just gets left-truncated by the LM automatically, and maybe that's good enough?
        for reading_test in root.iter('reading-test'):
            src = reading_test[0].text
            src = src.strip().replace("’", "'")  # normalize curly apostrophes
            for qid, question in enumerate(reading_test.iter('q')):
                out_doc = self._convert_standard(question)
                out_doc['source'] = src
                yield out_doc
    def fewshot_description(self):
        return ""

    def test_docs(self):
        return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])


class QA4MRE_2011(QA4MRE):
    YEAR = 2011


class QA4MRE_2012(QA4MRE):
    YEAR = 2012


class QA4MRE_2013(QA4MRE):
    YEAR = 2013
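For context, _convert_standard in the new module (presumably lm_eval/tasks/qa4mre.py, given the `from . import qa4mre` above) consumes <q> elements shaped like the hand-written fragment below; the schema is inferred from the code, not from an official QA4MRE spec:

import xml.etree.ElementTree as ET

# Minimal hand-made fragment mimicking one question element.
q = ET.fromstring("""
<q>
  <q_str>What color is the sky?</q_str>
  <answer a_id="1" correct="Yes">Blue</answer>
  <answer a_id="2">Green</answer>
</q>
""")
choices = [a.text for a in q.iter('answer')]
gold = int(q.find("./answer[@correct='Yes']").attrib["a_id"]) - 1
print(choices, gold)  # ['Blue', 'Green'] 0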
@@ -82,9 +82,12 @@ class RACE(HFTask):
     def doc_to_text(self, doc):
         text = 'Article: ' + doc['article'] + '\n\n'
         for problem in doc['problems'][:-1]:
-            assert problem['question'][-6:] == '  _  .'
-            text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+            if problem['question'][-6:] == '  _  .':
+                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+            else:
+                question = 'Question: ' + problem['question'] + '\n'
+                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+                text += question + answer
         text += self.last_problem(doc)['question']
         return text
@@ -3,6 +3,7 @@ import json
 from ..utils import sh
 from lm_eval.base import MultipleChoiceTask, rf, mean
 import zipfile
+from best_download import download_file


 class SciQ(MultipleChoiceTask):
@@ -10,9 +11,11 @@ class SciQ(MultipleChoiceTask):
     def download(self):
         if not os.path.exists('data/sciq'):
             os.mkdir('data/sciq')
-        sh((
-            "wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
-        ))
+        download_file(
+            'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
+            'data/sciq/SciQ.zip',
+            '7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
+        )
         with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
             zf.extractall("data/sciq/")
@@ -48,8 +51,6 @@ class SciQ(MultipleChoiceTask):
             yield self._convert_standard(record)

     def fewshot_description(self):
-        # Average ctx length in labelled dataset is 238.9
-        # 2 few-shot exmamples pushes it beyond context window
         return ""

     def training_docs(self):
@@ -218,7 +218,7 @@ class MultiRC(HFTask):
         return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"

     def doc_to_target(self, doc):
-        return self.format_answer(answer=doc["answer"], label=doc["label"])
+        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])

     @staticmethod
     def format_answer(answer, label):
@@ -20,7 +20,7 @@ def parse_args():
     parser.add_argument('--seed', type=int, default=1234)
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--cache', action="store_true")
+    parser.add_argument('--no_cache', action="store_true")
     return parser.parse_args()


def main():
@@ -31,7 +31,7 @@ def main():
     lm = models.get_model(args.model).create_from_arg_string(args.model_args)

-    if args.cache:
+    if not args.no_cache:
         lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')

     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
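This flips caching from opt-in to opt-out: CachingLM now wraps the model unless --no_cache is passed. A minimal sketch of the new flag logic in isolation (argparse only, names taken from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no_cache', action="store_true")
args = parser.parse_args([])              # no flag: caching stays enabled
assert not args.no_cache
args = parser.parse_args(['--no_cache'])  # flag given: caching disabled
assert args.no_cache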