You need to sign in or sign up before continuing.
Commit 52f270dc authored by Leo Gao's avatar Leo Gao
Browse files

triviaqa OOM fix

parent e4766cd7
import os
import json
import jsonlines
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
......@@ -27,10 +28,10 @@ class TriviaQA(Task):
return False
def training_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl'))
return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
def validation_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl'))
return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
def test_docs(self):
raise NotImplementedError()
......
......@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task):
task2 = Task()
limit = None
if taskname in ["triviaqa"]: limit = 10000
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit))
arr2 = list(islice(task2.validation_docs(), limit))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment