Commit 89c4ee5c authored by lintangsutawika's avatar lintangsutawika
Browse files

added include_benchmarks

parent 1d3a1601
import os
import yaml
from typing import List, Union
from lm_eval import utils
......@@ -39,9 +40,6 @@ def include_task_folder(task_dir):
)
if "task" in config:
# task_name = "{}:{}".format(
# get_task_name_from_config(config), config["task"]
# )
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
......@@ -57,8 +55,32 @@ def include_task_folder(task_dir):
)
def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
for root, subdirs, file_list in os.walk(os.path.join(task_dir, benchmark_dir)):
if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
benchmark_path = os.path.join(root, f)
with open(benchmark_path, "rb") as file:
yaml_config = yaml.full_load(file)
assert "group" in yaml_config
group = yaml_config["group"]
task_list = yaml_config["task"]
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
include_benchmarks(task_dir)
def get_task(task_name, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment