Commit 600918bb authored by lintangsutawika's avatar lintangsutawika
Browse files

can list other yamls to be included and loaded

parent 3cdc47c5
...@@ -13,6 +13,42 @@ from lm_eval.api.register import ( ...@@ -13,6 +13,42 @@ from lm_eval.api.register import (
group_registry group_registry
) )
def load_yaml_config(yaml_path):
with open(yaml_path, 'rb') as file:
yaml_config = yaml.full_load(file)
yaml_dir = os.path.dirname(yaml_path)
if 'include' in yaml_config:
include_path = yaml_config['include']
del yaml_config['include']
if type(include_path) == str:
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
try:
included_yaml_config = load_yaml_config(path)
final_yaml_config.update(included_yaml_config)
except:
# If failed to load, ignore
pass
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
for root, subdirs, file_list in os.walk(task_dir): for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == []) and (len(file_list) > 0): if (subdirs == []) and (len(file_list) > 0):
...@@ -20,7 +56,7 @@ for root, subdirs, file_list in os.walk(task_dir): ...@@ -20,7 +56,7 @@ for root, subdirs, file_list in os.walk(task_dir):
if "yaml" in file: if "yaml" in file:
yaml_path = os.path.join(root, file) yaml_path = os.path.join(root, file)
try: try:
config = yaml.full_load(open(yaml_path, "rb")) config = load_yaml_config(yaml_path)
SubClass = type( SubClass = type(
config['task']+'ConfigurableTask', config['task']+'ConfigurableTask',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment