Commit cb3babd3 authored by Leo Gao's avatar Leo Gao
Browse files

Improve pile test efficiency

parent 66fffcfc
......@@ -18,9 +18,10 @@ class PilePerplexityTask(PerplexityTask, abc.ABC):
def download(self):
# TODO: separate pile val/test out by component so we don't have to scan the entire file once per set
os.makedirs("data/pile/", exist_ok=True)
download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH, "264c875d8bbd355d8daa9d032b75fd8fb91606218bb84dd1155b203fcd5fab92")
download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH, "0bb28c52d0b5596d389bf179ce2d43bf7f7ffae76b0d2d20b180c97f62e0975e")
if not os.path.exists("data/pile/test.jsonl.zst"):
os.makedirs("data/pile/", exist_ok=True)
download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH, "264c875d8bbd355d8daa9d032b75fd8fb91606218bb84dd1155b203fcd5fab92")
download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH, "0bb28c52d0b5596d389bf179ce2d43bf7f7ffae76b0d2d20b180c97f62e0975e")
def validation_docs(self):
rdr = lm_dataformat.Reader(self.VAL_PATH)
......
......@@ -32,7 +32,7 @@ def test_basic_interface(taskname, task_class):
limit = None
if taskname in ["triviaqa"]:
if taskname in ["triviaqa"] or taskname.startswith("pile_"):
limit = 10000
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment