Commit 9e98d1a9 authored by lintangsutawika's avatar lintangsutawika
Browse files

removed test lines

parent e39ec52d
......@@ -477,10 +477,6 @@ def evaluate(
total_size = 0
for task in task_list:
print("###")
print(task)
print(metrics)
print("###")
metrics = results[task].copy()
if "alias" in metrics:
......@@ -492,6 +488,7 @@ def evaluate(
if weight_by_size:
current_size = metrics.pop("samples")
else:
metrics.pop("samples")
current_size = 1
all_stderr = []
......
......@@ -103,7 +103,6 @@ class TaskManager(abc.ABC):
task_object = (group, task_object)
return {task: task_object}
print("Loading", name_or_config)
if isinstance(name_or_config, str):
if self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
......@@ -198,30 +197,6 @@ class TaskManager(abc.ABC):
return tasks_and_groups
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
{"CONFIG": TaskConfig(**config)},
)
if "task" in config:
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
if config["group"] == config["task"]:
raise ValueError("task and group name cannot be the same")
elif type(config["group"]) == str:
group_name = [config["group"]]
else:
group_name = config["group"]
for group in group_name:
register_group(group)(SubClass)
return 0
def check_prompt_config(
config: Dict[str, str], yaml_path: str = None
......@@ -258,156 +233,66 @@ def check_prompt_config(
all_configs.append(config)
return all_configs
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task: bool = True, task_name: str = None) -> None:
"""
Calling this function
"""
# Track whether any tasks failed during loading
import_fail = False
for root, subdirs, file_list in os.walk(task_dir):
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
if "task" not in config:
continue
all_configs = check_prompt_config(
config, yaml_path=os.path.dirname(yaml_path)
)
for config in all_configs:
if register_task:
if type(config["task"]) == str:
register_configurable_task(config)
else:
if type(config["task"]) == list:
register_configurable_group(config, yaml_path)
# Log this silently and show it only when
# the user defines the appropriate verbosity.
except (ImportError, ModuleNotFoundError) as e:
import_fail = True
eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry."
)
except Exception as error:
import traceback
eval_logger.warning(
"Unexpected error loading config in\n"
f" {yaml_path}\n"
" Config will not be added to registry\n"
f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}"
)
if import_fail:
eval_logger.warning(
"Some tasks could not be loaded due to missing dependencies."
" Run with `--verbosity DEBUG` for full details."
)
return 0
def get_task(task_name, config):
try:
return TASK_REGISTRY[task_name](config=config)
except KeyError:
eval_logger.info("Available tasks:")
eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
task_name_from_registry_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
if type(task_name_list) != list:
task_name_list = [task_name_list]
for task_element in task_name_list:
if isinstance(task_element, str):
if task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict:
task_obj = get_task_dict(task_name)
if task_name in task_obj.keys():
task_dict = {
task_name: (group_name, task_obj[task_name]),
}
else:
task_dict = {
task_name: (group_name, None),
**task_obj,
}
task_name_from_registry_dict = {
**task_name_from_registry_dict,
**task_dict,
}
else:
task_name = task_element
if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: get_task(task_name=task_element, config=config),
}
elif isinstance(task_element, dict):
task_element.update(config)
task_name_from_config_dict = {
**task_name_from_config_dict,
get_task_name_from_config(task_element): ConfigurableTask(
config=task_element
),
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
}
assert set(task_name_from_registry_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
)
return {
**task_name_from_registry_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# # TODO: pass num_fewshot and other cmdline overrides in a better way
# def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
# config = {**kwargs}
# task_name_from_registry_dict = {}
# task_name_from_config_dict = {}
# task_name_from_object_dict = {}
# if type(task_name_list) != list:
# task_name_list = [task_name_list]
# for task_element in task_name_list:
# if isinstance(task_element, str):
# if task_element in GROUP_REGISTRY:
# group_name = task_element
# for task_name in GROUP_REGISTRY[task_element]:
# if task_name not in task_name_from_registry_dict:
# task_obj = get_task_dict(task_name)
# if task_name in task_obj.keys():
# task_dict = {
# task_name: (group_name, task_obj[task_name]),
# }
# else:
# task_dict = {
# task_name: (group_name, None),
# **task_obj,
# }
# task_name_from_registry_dict = {
# **task_name_from_registry_dict,
# **task_dict,
# }
# else:
# task_name = task_element
# if task_name not in task_name_from_registry_dict:
# task_name_from_registry_dict = {
# **task_name_from_registry_dict,
# task_name: get_task(task_name=task_element, config=config),
# }
# elif isinstance(task_element, dict):
# task_element.update(config)
# task_name_from_config_dict = {
# **task_name_from_config_dict,
# get_task_name_from_config(task_element): ConfigurableTask(
# config=task_element
# ),
# }
# elif isinstance(task_element, Task):
# task_name_from_object_dict = {
# **task_name_from_object_dict,
# get_task_name_from_object(task_element): task_element,
# }
# assert set(task_name_from_registry_dict.keys()).isdisjoint(
# set(task_name_from_object_dict.keys())
# )
# return {
# **task_name_from_registry_dict,
# **task_name_from_config_dict,
# **task_name_from_object_dict,
# }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment