Commit efa99cb2 authored by Leo Gao's avatar Leo Gao
Browse files

TriviaQA memory fix

parent c4ecbd6d
......@@ -108,7 +108,7 @@ TASK_REGISTRY = {
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
#"triviaqa": triviaqa.TriviaQA, # disabled pending memory fix
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet
......
......@@ -3,25 +3,19 @@ import json
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
from best_download import download_file
class TriviaQA(Task):
VERSION = 0
def download(self):
if not os.path.exists('data/triviaqa'):
# TODO: convert to best_download
if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
os.makedirs("data/triviaqa/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", "data/triviaqa/triviaqa-unfiltered.tar.gz", "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
sh("""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
cd data/triviaqa/
tar -xf triviaqa-unfiltered.tar.gz
""")
# convert to streamable jsonl
for subset in ['train', 'dev']:
with open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.jsonl', 'w') as fh:
for d in json.load(open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.json'))['Data']:
fh.write(json.dumps(d) + "\n")
def has_training_docs(self):
return True
......@@ -33,10 +27,10 @@ class TriviaQA(Task):
return False
def training_docs(self):
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.jsonl'))
return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl'))
def validation_docs(self):
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.jsonl'))
return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl'))
def test_docs(self):
raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment