Commit efa99cb2 authored by Leo Gao's avatar Leo Gao
Browse files

TriviaQA memory fix

parent c4ecbd6d
...@@ -108,7 +108,7 @@ TASK_REGISTRY = { ...@@ -108,7 +108,7 @@ TASK_REGISTRY = {
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013, "qa4mre_2013" : qa4mre.QA4MRE_2013,
#"triviaqa": triviaqa.TriviaQA, # disabled pending memory fix "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
......
...@@ -3,26 +3,20 @@ import json ...@@ -3,26 +3,20 @@ import json
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from ..metrics import mean from ..metrics import mean
from ..utils import sh from ..utils import sh
from best_download import download_file
class TriviaQA(Task): class TriviaQA(Task):
VERSION = 0 VERSION = 0
def download(self): def download(self):
if not os.path.exists('data/triviaqa'): if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
# TODO: convert to best_download os.makedirs("data/triviaqa/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", "data/triviaqa/triviaqa-unfiltered.tar.gz", "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
sh(""" sh("""
mkdir -p data/triviaqa cd data/triviaqa/
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz tar -xf triviaqa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
""") """)
# convert to streamable jsonl
for subset in ['train', 'dev']:
with open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.jsonl', 'w') as fh:
for d in json.load(open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.json'))['Data']:
fh.write(json.dumps(d) + "\n")
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -33,10 +27,10 @@ class TriviaQA(Task): ...@@ -33,10 +27,10 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.jsonl')) return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl'))
def validation_docs(self): def validation_docs(self):
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.jsonl')) return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl'))
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment