Commit e971baac authored by Leo Gao
Browse files

Fix triviaqa memory consumption problem

(or rather, move it to the data download phase)
parent 2b8956b8
...@@ -101,7 +101,7 @@ TASK_REGISTRY = { ...@@ -101,7 +101,7 @@ TASK_REGISTRY = {
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013, "qa4mre_2013" : qa4mre.QA4MRE_2013,
#"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
......
...@@ -14,6 +14,12 @@ class TriviaQA(Task): ...@@ -14,6 +14,12 @@ class TriviaQA(Task):
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/ mv triviaqa-unfiltered/ data/triviaqa/
""") """)
# convert to streamable jsonl
for subset in ['train', 'dev']:
with open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.jsonl', 'w') as fh:
for d in json.load(open(f'data/triviaqa/triviaqa-unfiltered/unfiltered-web-{subset}.json'))['Data']:
fh.write(json.dumps(d) + "\n")
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -25,20 +31,20 @@ class TriviaQA(Task): ...@@ -25,20 +31,20 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'))['Data'] return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.jsonl'))
def validation_docs(self): def validation_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.json'))['Data'] return map(json.loads, open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-val.jsonl'))
def test_docs(self): def test_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-test.json'))['Data'] raise NotImplementedError()
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out fewshot description # TODO: figure out fewshot description
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return ''.join(['Q:', doc['Question'], '\n\n','A:']) return f"Question: {doc['Question']}\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['Answer']['Value'] return " " + doc['Answer']['Value']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment