"vscode:/vscode.git/clone" did not exist on "af97ec2f4c9daac091b9a87355c4f22d37488004"
Commit d4ae0c00 authored by Leo Gao
Browse files

Allow specifying sets for write_out

parent 1c9432de
...@@ -25,3 +25,7 @@ def simple_parse_args_string(args_string): ...@@ -25,3 +25,7 @@ def simple_parse_args_string(args_string):
k, v = arg.split("=") k, v = arg.split("=")
args_dict[k] = v args_dict[k] = v
return args_dict return args_dict
def join_iters(iters):
    """Lazily concatenate several iterables into one stream of items.

    Args:
        iters: An iterable whose elements are themselves iterables.

    Yields:
        Each item of the first inner iterable, then each item of the
        second, and so on, preserving order. Inner iterables are only
        consumed as the resulting generator is advanced.
    """
    # The loop variable is not called `iter` so the builtin stays usable.
    for iterable in iters:
        yield from iterable
\ No newline at end of file
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import random import random
from lm_eval import tasks from lm_eval import tasks
from lm_eval.utils import join_iters
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n" EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
...@@ -13,6 +14,7 @@ def parse_args(): ...@@ -13,6 +14,7 @@ def parse_args():
parser.add_argument('--output_base_path', required=True) parser.add_argument('--output_base_path', required=True)
parser.add_argument('--tasks', default="all_tasks") parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true") parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--sets', type=str, default="val") # example: val,test
parser.add_argument('--num_fewshot', type=int, default=1) parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--num_examples', type=int, default=1) parser.add_argument('--num_examples', type=int, default=1)
...@@ -31,10 +33,21 @@ def main(): ...@@ -31,10 +33,21 @@ def main():
task_dict = tasks.get_task_dict(task_names) task_dict = tasks.get_task_dict(task_names)
os.makedirs(args.output_base_path, exist_ok=True) os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if not task.has_validation_docs():
docs = task.training_docs() iters = []
else:
docs = task.validation_docs() for set in args.sets.split(","):
if set == 'val':
if task.has_train_docs():
docs = task.train_docs()
elif task.has_validation_docs():
docs = task.validation_docs()
elif task.has_test_docs():
docs = task.test_docs()
iters.append(docs)
docs = join_iters(iters)
with open(os.path.join(args.output_base_path, task_name), "w") as f: with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs): for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i)) f.write(EXAMPLE_DIVIDER.format(i=i))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment