"models/vision/ddim/example.py" did not exist on "25feac9e65ff7a7ca87d75150555bc010f3dfdd0"
Commit 52f270dc authored by Leo Gao's avatar Leo Gao
Browse files

triviaqa OOM fix

parent e4766cd7
import os import os
import json import json
import jsonlines
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from ..metrics import mean from ..metrics import mean
from ..utils import sh from ..utils import sh
...@@ -27,10 +28,10 @@ class TriviaQA(Task): ...@@ -27,10 +28,10 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-train.jsonl')) return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
def validation_docs(self): def validation_docs(self):
return map(json.loads, open('data/triviaqa/unfiltered-web-dev.jsonl')) return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task): ...@@ -30,6 +30,8 @@ def test_basic_interface(taskname, Task):
task2 = Task() task2 = Task()
limit = None limit = None
if taskname in ["triviaqa"]: limit = 10000
if task.has_validation_docs(): if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit)) arr = list(islice(task.validation_docs(), limit))
arr2 = list(islice(task2.validation_docs(), limit)) arr2 = list(islice(task2.validation_docs(), limit))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment