Unverified Commit a0243d54 authored by anthony-dipofi's avatar anthony-dipofi Committed by GitHub
Browse files

Prettify lm_eval --tasks list (#1929)



* add  and ; move task list newline logic to new TaskManager.list_all_tasks() method

* format table list into markdown table; add config location column

* add Output Type column

* add logic for printing table of tags separately

* merge with main and fix conflicts ; update docstrings

---------
Co-authored-by: default avatarhaileyschoelkopf <hailey@eleuther.ai>
parent 30273b47
......@@ -73,7 +73,7 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
type=str,
metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list",
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
)
parser.add_argument(
"--model_args",
......@@ -318,9 +318,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
)
print(task_manager.list_all_tasks())
sys.exit()
elif args.tasks == "list_groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
sys.exit()
elif args.tasks == "list_tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
sys.exit()
elif args.tasks == "list_subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
sys.exit()
else:
if os.path.isdir(args.tasks):
......@@ -349,7 +356,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
)
raise ValueError(
f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
)
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
......
......@@ -36,6 +36,16 @@ class TaskManager:
)
self._all_tasks = sorted(list(self._task_index.keys()))
self._all_groups = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
)
self._all_subtasks = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "task"]
)
self._all_tags = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
)
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(
......@@ -73,10 +83,88 @@ class TaskManager:
def all_tasks(self):
return self._all_tasks
@property
def all_groups(self):
return self._all_groups
@property
def all_subtasks(self):
return self._all_subtasks
@property
def all_tags(self):
return self._all_tags
@property
def task_index(self):
return self._task_index
def list_all_tasks(
self, list_groups=True, list_tags=True, list_subtasks=True
) -> str:
from pytablewriter import MarkdownTableWriter
def sanitize_path(path):
# don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path.
if "lm_eval/tasks/" in path:
return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
else:
return path
group_table = MarkdownTableWriter()
group_table.headers = ["Group", "Config Location"]
gt_values = []
for g in self.all_groups:
path = self.task_index[g]["yaml_path"]
if path == -1:
path = "---"
else:
path = sanitize_path(path)
gt_values.append([g, path])
group_table.value_matrix = gt_values
tag_table = MarkdownTableWriter()
tag_table.headers = ["Tag"]
tag_table.value_matrix = [[t] for t in self.all_tags]
subtask_table = MarkdownTableWriter()
subtask_table.headers = ["Task", "Config Location", "Output Type"]
st_values = []
for t in self.all_subtasks:
path = self.task_index[t]["yaml_path"]
output_type = ""
# read the yaml file to determine the output type
if path != -1:
config = utils.load_yaml_config(path, mode="simple")
if "output_type" in config:
output_type = config["output_type"]
elif (
"include" in config
): # if no output type, check if there is an include with an output type
include_path = path.split("/")[:-1] + config["include"]
include_config = utils.load_yaml_config(include_path, mode="simple")
if "output_type" in include_config:
output_type = include_config["output_type"]
if path == -1:
path = "---"
else:
path = sanitize_path(path)
st_values.append([t, path, output_type])
subtask_table.value_matrix = st_values
result = "\n"
if list_groups:
result += group_table.dumps() + "\n\n"
if list_tags:
result += tag_table.dumps() + "\n\n"
if list_subtasks:
result += subtask_table.dumps() + "\n\n"
return result
def match_tasks(self, task_list):
return utils.pattern_match(task_list, self.all_tasks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment