Commit ad70d206 authored by lintangsutawika's avatar lintangsutawika
Browse files

update to work with new group and task configuration

parent c23c9305
......@@ -60,6 +60,7 @@ class GroupConfig(dict):
aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False
metric_alias: Optional[str] = None
version: Optional[str] = 0
def __getitem__(self, item):
return getattr(self, item)
......@@ -113,21 +114,25 @@ class ConfigurableGroup(abc.ABC):
@property
def group(self):
return self._config.group
@property
def group_alias(self):
return self._config.group_alias
@property
def version(self):
return self._config.version
@property
def config(self):
return self._config.to_dict()
def __repr__(self):
return (
f"ConfigurableGroup(group={self.group},"
f"group_alias={self.group_alias})"
f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
)
@dataclass
class TaskConfig(dict):
# task naming/registry
......
......@@ -17,12 +17,16 @@ from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
get_task_list,
prepare_print_tasks,
print_writeout,
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import ConfigurableGroup, ConfigurableTask, TaskManager, get_task_dict
from lm_eval.tasks import (
ConfigurableGroup,
ConfigurableTask,
TaskManager,
get_task_dict,
)
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
......@@ -211,14 +215,14 @@ def simple_evaluate(
task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager)
def _adjust_config(task_dict):
def _adjust_config(task_dict):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj)}
**{task_name: _adjust_config(task_obj)},
}
else:
......@@ -229,7 +233,6 @@ def simple_evaluate(
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
......@@ -250,9 +253,11 @@ def simple_evaluate(
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
if (
default_num_fewshot := task_obj.get_config("num_fewshot")
) is None:
task_obj.set_config(key="num_fewshot", value=0)
adjusted_task_dict[task_name] = task_obj
return adjusted_task_dict
......@@ -266,7 +271,7 @@ def simple_evaluate(
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
write_out=write_out,
log_samples=log_samples,
log_samples=True if predict_only else log_samples,
verbosity=verbosity,
)
......@@ -340,9 +345,6 @@ def evaluate(
# get lists of group hierarchy and each type of request
eval_tasks = get_task_list(task_dict)
# print("task_hierarchy")
# print(task_hierarchy)
# import sys; sys.exit()
if not log_samples:
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
......@@ -496,8 +498,14 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
def process_group(results, task_dict, task_root=None, task_hierarchy=None, show_group_table=False):
def process_group(
results,
task_dict,
task_root=None,
task_hierarchy=None,
show_group_table=False,
):
if task_root is None:
task_root = {}
......@@ -505,25 +513,33 @@ def evaluate(
task_hierarchy = {}
for group_or_task, group_or_task_info in task_dict.items():
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.group
show_group_table = (
show_group_table | group_config["aggregate_metric"]
)
if group_config["aggregate_metric"] is False:
results[group_or_task][" "] = " "
continue
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_hierarchy.setdefault(task_root, []).append(group_or_task)
task_hierarchy.setdefault(task_root, []).append(
group_or_task
)
else:
results, _task_hierarchy, show_group_table = process_group(results, group_or_task_info, group_or_task, task_hierarchy, show_group_table)
results, _task_hierarchy, show_group_table = process_group(
results,
group_or_task_info,
group_or_task,
task_hierarchy,
show_group_table,
)
if task_root:
task_hierarchy.setdefault(task_root, []).extend(task_hierarchy.get(group_or_task, []))
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group = group_or_task.group
show_group_table = show_group_table | group_config["aggregate_metric"]
if group_config["aggregate_metric"] is False:
results[group][" "] = " "
continue
elif isinstance(group_or_task, str):
results[group_or_task][" "] = " "
continue
task_hierarchy.setdefault(task_root, []).extend(
task_hierarchy.get(group_or_task, [])
)
task_list = _task_hierarchy[group_or_task]
metric_list = list(
......@@ -531,7 +547,8 @@ def evaluate(
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
if "_stderr" not in key
and key not in ["alias", "samples"]
}
)
for metric in metric_list:
......@@ -555,7 +572,7 @@ def evaluate(
]
# compute group's pooled metric and stderr
results[group][
results[group_or_task][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
......@@ -564,54 +581,69 @@ def evaluate(
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
results[group_or_task][stderr] = "N/A"
else:
results[group][
results[group_or_task][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
] = lm_eval.api.metrics.pooled_sample_stderr(
stderrs, sizes
)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
results[group_or_task]["samples"] = sum(sizes)
return results, task_hierarchy, show_group_table
results, task_hierarchy, show_group_table = process_group(results, task_dict)
print(task_hierarchy)
import sys; sys.exit()
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
if len(left_tasks_list) == 0:
break
_task_hierarchy = {
k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
for group_name, group_info in task_hierarchy.items():
task_list = group_info["tasks"]
if task_list:
num_fewshot[group_name] = num_fewshot[
task_list[0]
] # TODO: validate this
import sys; sys.exit()
results, task_hierarchy, show_group_table = process_group(
results, task_dict
)
def print_table(task_dict, results, task_depth=0):
task_agg = defaultdict(dict)
for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
from_configurable_group = False
task_agg[name] = results[name].copy()
if from_configurable_group:
if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias
else:
alias = name
else:
if "alias" in task_agg[name]:
alias = task_agg[name]["alias"]
else:
alias = name
task_agg[name]["alias"] = tab_string + alias
if "samples" in task_agg[name]:
task_agg[name].pop("samples")
if isinstance(task_or_group_obj, dict):
task_depth += 1
task_agg = {
**task_agg,
**print_table(task_or_group_obj, results, task_depth),
}
task_depth -= 1
return task_agg
results_agg = print_table(task_dict, results)
results_dict = {
"results": dict(results_agg.items()),
**(
{"groups": dict(groups_agg.items())}
if (bool(groups_agg) & show_group_table)
else {}
),
# **(
# {"groups": dict(groups_agg.items())}
# if (bool(groups_agg) & show_group_table)
# else {}
# ),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
......
......@@ -2,11 +2,11 @@ import collections
import math
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated
from lm_eval.api.task import ConfigurableTask, ConfigurableGroup
class TaskOutput:
"""
......@@ -121,14 +121,13 @@ class TaskOutput:
def get_task_list(task_dict: dict) -> List[TaskOutput]:
outputs = []
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
_outputs = get_task_list(task_obj)
outputs.extend(_outputs)
else:
task_output = TaskOutput.from_taskdict(task_name, task_obj)
outputs.append(task_output)
return outputs
......
......@@ -5,7 +5,8 @@ from functools import partial
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, ConfigurableGroup, GroupConfig, Task
from lm_eval.api.task import ConfigurableGroup, ConfigurableTask, GroupConfig, Task
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
......
......@@ -6,3 +6,4 @@ task:
- mmlu_humanities
aggregate_metric: True
weight_by_size: True
version: 1
......@@ -16,3 +16,4 @@ task:
# - mmlu_world_religions
aggregate_metric: True
weight_by_size: True
version: 1
......@@ -16,3 +16,4 @@ task:
# - mmlu_virology
aggregate_metric: True
weight_by_size: True
version: 1
......@@ -15,3 +15,4 @@ task:
# - mmlu_us_foreign_policy
aggregate_metric: True
weight_by_size: True
version: 1
......@@ -22,3 +22,4 @@ task:
# - mmlu_machine_learning
aggregate_metric: True
weight_by_size: True
version: 1
......@@ -242,8 +242,11 @@ def make_table(result_dict, column: str = "results"):
values = []
for k, dic in result_dict[column].items():
version = result_dict["versions"].get(k, "N/A")
n = str(result_dict["n-shot"][k])
version = result_dict["versions"].get(k, " N/A")
if k in result_dict["n-shot"]:
n = str(result_dict["n-shot"][k])
else:
n = " "
if "alias" in dic:
k = dic.pop("alias")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment