Commit 52f270dc authored by Leo Gao's avatar Leo Gao
Browse files

triviaqa OOM fix

parent e4766cd7
import os
import json
import jsonlines
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
......@@ -27,10 +28,10 @@ class TriviaQA(Task):
return False
def training_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl'))
return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
def validation_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl'))
return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
def test_docs(self):
raise NotImplementedError()
......
......@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task):
task2 = Task()
limit = None
if taskname in ["triviaqa"]: limit = 10000
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit))
arr2 = list(islice(task2.validation_docs(), limit))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment