Commit 7aee2dff authored by lintangsutawika
Browse files

moved functions out and some fixes

parent 66c58194
import os import os
import re
import yaml
from typing import List, Union from typing import List, Union
from .arc import * from .arc import *
from lm_eval import utils
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import ( from lm_eval.api.register import (
register_task, register_task,
...@@ -14,39 +13,8 @@ from lm_eval.api.register import ( ...@@ -14,39 +13,8 @@ from lm_eval.api.register import (
) )
def load_yaml_config(yaml_path):
    """Load a YAML config file and merge in the files named by its ``include`` key.

    ``include`` may be a single path or a list of paths. Each included file is
    loaded recursively with this same function. Included configs are merged
    starting from the last entry, so earlier entries override later ones, and
    the keys of the including file override everything that was included.
    The ``include`` key itself is removed from the returned config.

    Args:
        yaml_path: Path of the YAML file to load.

    Returns:
        The parsed config. When the file has an ``include`` key this is a new
        dict holding the merged result; otherwise it is the parsed document
        itself (an empty dict for an empty file).

    Raises:
        OSError: If ``yaml_path`` itself cannot be opened. Failures while
            loading an *included* file are logged and skipped instead.
    """
    # Local import so the block does not depend on a file-level logging import.
    import logging

    with open(yaml_path, 'rb') as file:
        # NOTE(review): full_load is not meant for untrusted input; consider
        # yaml.safe_load if configs can come from outside the repository.
        yaml_config = yaml.full_load(file)

    # An empty YAML document parses to None; treat it as an empty mapping so
    # the membership test below does not raise TypeError.
    if yaml_config is None:
        yaml_config = {}

    if 'include' not in yaml_config:
        return yaml_config

    yaml_dir = os.path.dirname(yaml_path)
    include_path = yaml_config.pop('include')
    if isinstance(include_path, str):
        include_path = [include_path]

    final_yaml_config = {}
    # Load from the last one first (without mutating the parsed list).
    for path in reversed(include_path):
        # Assumes that path is a full path.
        # If not found, assume the included yaml
        # is in the same dir as the original yaml
        if not os.path.isfile(path):
            path = os.path.join(yaml_dir, path)
        try:
            final_yaml_config.update(load_yaml_config(path))
        except Exception:
            # Best effort: a broken include is skipped, but leave a trace
            # instead of failing silently. KeyboardInterrupt and SystemExit
            # are deliberately not caught.
            logging.getLogger(__name__).warning(
                "Skipping include %r referenced by %r", path, yaml_path,
                exc_info=True,
            )

    final_yaml_config.update(yaml_config)
    return final_yaml_config


def get_task_name_from_config(task_config):
    """Build the registry name for a task described by a config mapping.

    Args:
        task_config: Mapping with a ``dataset_path`` key and, optionally, a
            ``dataset_name`` key. Other keys are ignored.

    Returns:
        ``"configurable_<dataset_path>_<dataset_name>"``, or
        ``"configurable_<dataset_path>"`` when ``dataset_name`` is absent.

    Raises:
        KeyError: If ``dataset_path`` is missing.
    """
    # A config without dataset_name used to raise KeyError here, which the
    # registration loop's blanket ``except`` turned into a silently
    # unregistered task.
    if "dataset_name" not in task_config:
        return "configurable_{dataset_path}".format(**task_config)
    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
...@@ -56,7 +24,7 @@ for root, subdirs, file_list in os.walk(task_dir): ...@@ -56,7 +24,7 @@ for root, subdirs, file_list in os.walk(task_dir):
if "yaml" in file: if "yaml" in file:
yaml_path = os.path.join(root, file) yaml_path = os.path.join(root, file)
try: try:
config = load_yaml_config(yaml_path) config = utils.load_yaml_config(yaml_path)
SubClass = type( SubClass = type(
config['task']+'ConfigurableTask', config['task']+'ConfigurableTask',
...@@ -65,13 +33,17 @@ for root, subdirs, file_list in os.walk(task_dir): ...@@ -65,13 +33,17 @@ for root, subdirs, file_list in os.walk(task_dir):
) )
if 'task' in config: if 'task' in config:
register_task(config['task'])(SubClass) task_name = "{}:{}".format(
get_task_name_from_config(config),
config['task']
)
register_task(task_name)(SubClass)
if 'group' in config: if 'group' in config:
for group in config['group']: for group in config['group']:
register_group(group)(SubClass) register_group(group)(SubClass)
except: except:
pass pass
TASK_REGISTRY = task_registry TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry GROUP_REGISTRY = group_registry
...@@ -100,10 +72,6 @@ def get_task_name_from_object(task_object): ...@@ -100,10 +72,6 @@ def get_task_name_from_object(task_object):
) )
def get_task_name_from_config(task_config):
    """Build the registry name for a task described by a config mapping.

    Args:
        task_config: Mapping with a ``dataset_path`` key and, optionally, a
            ``dataset_name`` key. Other keys are ignored.

    Returns:
        ``"configurable_<dataset_path>_<dataset_name>"``, or
        ``"configurable_<dataset_path>"`` when ``dataset_name`` is absent.

    Raises:
        KeyError: If ``dataset_path`` is missing.
    """
    # A config without dataset_name previously raised KeyError instead of
    # yielding a usable name.
    if "dataset_name" not in task_config:
        return "configurable_{dataset_path}".format(**task_config)
    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
# TODO: pass num_fewshot and other cmdline overrides in a better way # TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs): def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
...@@ -126,6 +94,7 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs): ...@@ -126,6 +94,7 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
) )
} }
else: else:
task_name = task_element
if task_name not in task_name_from_registry_dict: if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = { task_name_from_registry_dict = {
**task_name_from_registry_dict, **task_name_from_registry_dict,
...@@ -135,11 +104,11 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs): ...@@ -135,11 +104,11 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
} }
elif isinstance(task_element, dict): elif isinstance(task_element, dict):
task_element.update(config)
task_name_from_config_dict = { task_name_from_config_dict = {
**task_name_from_config_dict, **task_name_from_config_dict,
get_task_name_from_config(task_element): ConfigurableTask( get_task_name_from_config(task_element): ConfigurableTask(
config=config config=task_element
) )
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment