Commit 9e98d1a9 authored by lintangsutawika
Browse files

removed test lines

parent e39ec52d
...@@ -477,10 +477,6 @@ def evaluate( ...@@ -477,10 +477,6 @@ def evaluate(
total_size = 0 total_size = 0
for task in task_list: for task in task_list:
print("###")
print(task)
print(metrics)
print("###")
metrics = results[task].copy() metrics = results[task].copy()
if "alias" in metrics: if "alias" in metrics:
...@@ -492,6 +488,7 @@ def evaluate( ...@@ -492,6 +488,7 @@ def evaluate(
if weight_by_size: if weight_by_size:
current_size = metrics.pop("samples") current_size = metrics.pop("samples")
else: else:
metrics.pop("samples")
current_size = 1 current_size = 1
all_stderr = [] all_stderr = []
......
...@@ -103,7 +103,6 @@ class TaskManager(abc.ABC): ...@@ -103,7 +103,6 @@ class TaskManager(abc.ABC):
task_object = (group, task_object) task_object = (group, task_object)
return {task: task_object} return {task: task_object}
print("Loading", name_or_config)
if isinstance(name_or_config, str): if isinstance(name_or_config, str):
if self._name_is_task(name_or_config): if self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
...@@ -198,30 +197,6 @@ class TaskManager(abc.ABC): ...@@ -198,30 +197,6 @@ class TaskManager(abc.ABC):
return tasks_and_groups return tasks_and_groups
def register_configurable_task(config: Dict[str, str]) -> int:
    """Build a ``ConfigurableTask`` subclass from ``config`` and register it.

    The generated class is named ``<task>ConfigurableTask`` and carries a
    ``CONFIG`` class attribute set to ``TaskConfig(**config)``. It is
    registered under the task name and, when ``config`` has a ``"group"``
    entry, under each listed group as well.

    Args:
        config: Task configuration. Must contain a ``"task"`` key; may
            contain a ``"group"`` key holding one group name or an iterable
            of group names.

    Returns:
        Always ``0``.

    Raises:
        KeyError: If ``config`` has no ``"task"`` entry.
        ValueError: If the group name equals the task name.
    """
    # "task" is mandatory: the subclass name is derived from it, so a missing
    # key fails here rather than being silently skipped further down.
    task_name = config["task"]

    SubClass = type(
        task_name + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )
    register_task(str(task_name))(SubClass)

    if "group" in config:
        group = config["group"]
        if group == task_name:
            raise ValueError("task and group name cannot be the same")

        # A lone string is a single group name, not an iterable of characters.
        group_names = [group] if isinstance(group, str) else group
        for group_name in group_names:
            register_group(group_name)(SubClass)

    return 0
def check_prompt_config( def check_prompt_config(
config: Dict[str, str], yaml_path: str = None config: Dict[str, str], yaml_path: str = None
...@@ -258,156 +233,66 @@ def check_prompt_config( ...@@ -258,156 +233,66 @@ def check_prompt_config(
all_configs.append(config) all_configs.append(config)
return all_configs return all_configs
# # TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from the dataset fields of ``task_config``.

    The name is ``<dataset_path>_<dataset_name>`` when the config has a
    ``"dataset_name"`` key, and just ``<dataset_path>`` otherwise.
    """
    has_subset = "dataset_name" in task_config
    template = "{dataset_path}_{dataset_name}" if has_subset else "{dataset_path}"
    return template.format(**task_config)
# task_name_from_object_dict = {}
def include_task_folder(task_dir: str, register_task: bool = True, task_name: str = None) -> int:
    """Walk ``task_dir`` and register the tasks or groups defined in its YAML files.

    Every ``*.yaml`` file found under ``task_dir`` is loaded. Files without a
    ``"task"`` key are ignored. The remaining configs are expanded through
    ``check_prompt_config`` and then registered.

    Args:
        task_dir: Root directory to search recursively.
        register_task: When true, configs whose ``"task"`` is a string are
            registered as tasks. When false, configs whose ``"task"`` is a
            list are registered as groups instead.
        task_name: Not used by this function.

    Returns:
        Always ``0``.

    Loading errors never propagate: import failures are logged at debug
    level (with one summary warning at the end), and any other exception is
    logged as a warning together with its traceback.
    """
    # Track whether any tasks failed during loading
    import_fail = False
    for root, _subdirs, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)

                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for task_config in all_configs:
                    if register_task:
                        if isinstance(task_config["task"], str):
                            register_configurable_task(task_config)
                    elif isinstance(task_config["task"], list):
                        register_configurable_group(task_config, yaml_path)

            # Log this silently and show it only when
            # the user defines the appropriate verbosity.
            except (ImportError, ModuleNotFoundError) as e:
                import_fail = True
                eval_logger.debug(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                import traceback

                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f" {yaml_path}\n"
                    " Config will not be added to registry\n"
                    f" Error: {error}\n"
                    f" Traceback: {traceback.format_exc()}"
                )

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
# **task_name_from_object_dict,
# }
def get_task(task_name, config):
    """Instantiate the task class registered under ``task_name``.

    Args:
        task_name: Key to look up in ``TASK_REGISTRY``.
        config: Passed to the task class as the ``config`` keyword argument.

    Returns:
        The object produced by calling the registered class.

    Raises:
        KeyError: If ``task_name`` is not in ``TASK_REGISTRY``. The available
            task and group names are logged at info level first.
    """
    # Only the registry lookup is guarded, so a KeyError raised while
    # constructing the task is not misreported as a missing task.
    try:
        task_cls = TASK_REGISTRY[task_name]
    except KeyError as err:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}") from err

    return task_cls(config=config)
def get_task_name_from_object(task_object):
    """Return the reporting name for ``task_object``.

    If ``task_object`` is one of the values stored in ``TASK_REGISTRY``, the
    key it is stored under is returned. Otherwise the name falls back to the
    object's ``EVAL_HARNESS_NAME`` attribute when present, or to the name of
    its type.
    """
    for registered_name in TASK_REGISTRY:
        if TASK_REGISTRY[registered_name] is task_object:
            return registered_name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME
    return type(task_object).__name__
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Resolve a mixed list of task specifications into a name -> task mapping.

    Each element may be:

    * a string naming a group in ``GROUP_REGISTRY``: every member is resolved
      recursively and stored as ``(group_name, task)``; a member that does not
      resolve to an entry under its own name is stored as
      ``(group_name, None)`` alongside whatever it resolved to;
    * any other string: looked up through ``get_task`` with the keyword
      overrides as its config;
    * a dict: updated in place with the keyword overrides and wrapped in a
      ``ConfigurableTask``;
    * a ``Task`` instance: used as is, keyed by
      ``get_task_name_from_object``.

    A non-list argument is treated as a single-element list. Names coming
    from the registry must not collide with names coming from task objects.
    """
    config = dict(kwargs)

    from_registry = {}
    from_config = {}
    from_object = {}

    if type(task_name_list) != list:
        task_name_list = [task_name_list]

    for element in task_name_list:
        if isinstance(element, str):
            if element not in GROUP_REGISTRY:
                # Plain task name: instantiate it once.
                if element not in from_registry:
                    from_registry[element] = get_task(task_name=element, config=config)
                continue

            # Group name: expand every member that has not been seen yet.
            for member in GROUP_REGISTRY[element]:
                if member in from_registry:
                    continue
                resolved = get_task_dict(member)
                if member in resolved.keys():
                    from_registry[member] = (element, resolved[member])
                else:
                    from_registry[member] = (element, None)
                    from_registry.update(resolved)

        elif isinstance(element, dict):
            element.update(config)
            # Compute the key before building the task, as the original did.
            config_key = get_task_name_from_config(element)
            from_config[config_key] = ConfigurableTask(config=element)

        elif isinstance(element, Task):
            object_key = get_task_name_from_object(element)
            from_object[object_key] = element

    assert set(from_registry.keys()).isdisjoint(set(from_object.keys()))

    return {**from_registry, **from_config, **from_object}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment