"vscode:/vscode.git/clone" did not exist on "a4c70fe157477fc25940a0ff1f544632464f2e77"
Commit 49cc6f5d authored by Leo Gao
Browse files

Allow specifying sets for write_out

parent 1c9432de
......@@ -25,3 +25,7 @@ def simple_parse_args_string(args_string):
k, v = arg.split("=")
args_dict[k] = v
return args_dict
def join_iters(iters):
    """Lazily concatenate several iterables into a single stream.

    Args:
        iters: An iterable whose elements are themselves iterables.

    Yields:
        Every item of the first inner iterable, then every item of the
        second, and so on, in order. Inner iterables are only consumed
        as the result is iterated.
    """
    # Named `iterable` rather than `iter` so the builtin is not shadowed.
    for iterable in iters:
        yield from iterable
\ No newline at end of file
......@@ -4,6 +4,7 @@ import os
import random
from lm_eval import tasks
from lm_eval.utils import join_iters
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
......@@ -13,6 +14,7 @@ def parse_args():
parser.add_argument('--output_base_path', required=True)
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--sets', type=str, default="val") # example: val,test
parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--num_examples', type=int, default=1)
......@@ -31,10 +33,20 @@ def main():
task_dict = tasks.get_task_dict(task_names)
os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items():
if not task.has_validation_docs():
docs = task.training_docs()
else:
docs = task.validation_docs()
iters = []
for set in args.sets.split(","):
if set == 'train' and task.has_train_docs():
docs = task.train_docs()
if set == 'val' and task.has_validation_docs():
docs = task.validation_docs()
if set == 'test' and task.has_test_docs():
docs = task.test_docs()
iters.append(docs)
docs = join_iters(iters)
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment