Commit e971baac authored by Leo Gao's avatar Leo Gao
Browse files

Fix triviaqa memory consumption problem

(or rather, move it to the data download phase)
parent 2b8956b8
......@@ -101,7 +101,7 @@ TASK_REGISTRY = {
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
#"triviaqa": triviaqa.TriviaQA,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet
......
......@@ -14,6 +14,12 @@ class TriviaQA(Task):
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
""")
# convert to streamable jsonl
for subset in ['train', 'dev']:
with open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.jsonl', 'w') as fh:
for d in json.load(open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.json'))['Data']:
fh.write(json.dumps(d) + "\n")
def has_training_docs(self):
return True
......@@ -25,20 +31,20 @@ class TriviaQA(Task):
return False
def training_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'))['Data']
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.jsonl'))
def validation_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.json'))['Data']
return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-val.jsonl'))
def test_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-test.json'))['Data']
raise NotImplementedError()
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
return ''.join(['Q:', doc['Question'], '\n\n','A:'])
return f"Question: {doc['Question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + doc['Answer']['Value']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment