Commit 52f270dc authored by Leo Gao's avatar Leo Gao
Browse files

triviaqa OOM fix

parent e4766cd7
import os import os
import json import json
import jsonlines
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from ..metrics import mean from ..metrics import mean
from ..utils import sh from ..utils import sh
...@@ -27,10 +28,10 @@ class TriviaQA(Task): ...@@ -27,10 +28,10 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl')) return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
def validation_docs(self): def validation_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl')) return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task): ...@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task):
task2 = Task() task2 = Task()
limit = None limit = None
if taskname in ["triviaqa"]: limit = 10000
if task.has_validation_docs(): if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit)) arr = list(islice(task.validation_docs(), limit))
arr2 = list(islice(task2.validation_docs(), limit)) arr2 = list(islice(task2.validation_docs(), limit))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment