Commit 62572f05 authored by lintangsutawika

adjust group scoring using ConfigurableGroup

parent fe2a1472
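
Note: this change assumes task dictionaries can now be nested, with ConfigurableGroup keys mapping to sub-dicts of ConfigurableTask entries. A minimal sketch of that shape (illustrative names only; plain objects stand in for the harness classes, not part of the commit):

# Group keys map to sub-dicts of subtasks; top-level tasks map directly to task objects.
task_dict = {
    "example_group": {                    # a ConfigurableGroup key in the harness
        "example_subtask_a": object(),    # ConfigurableTask instances in practice
        "example_subtask_b": object(),
    },
    "example_standalone_task": object(),  # a task with no parent group
}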
@@ -22,7 +22,7 @@ from lm_eval.evaluator_utils import (
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.tasks import ConfigurableGroup, ConfigurableTask, TaskManager, get_task_dict
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
@@ -204,51 +204,60 @@ def simple_evaluate(
+ ".db",
)
if check_integrity:
run_task_tests(task_list=tasks)
if task_manager is None:
task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if isinstance(task_obj, lm_eval.api.task.ConfigurableTask) is False:
continue
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")
# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
task_obj.set_config(key="num_fewshot", value=0)
def _adjust_config(task_dict):
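# Recursively apply the per-task overrides below (generation kwargs, predict_only
# bypass metric, num_fewshot) to every leaf task in a possibly nested task dict.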
if check_integrity:
run_task_tests(task_list=tasks)
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj)}
}
else:
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")
# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
task_obj.set_config(key="num_fewshot", value=0)
adjusted_task_dict[task_name] = task_obj
return adjusted_task_dict
task_dict = _adjust_config(task_dict)
results = evaluate(
lm=lm,
task_dict=task_dict,
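
A standalone sketch of the recursive pattern used by the new _adjust_config helper above, with a small stand-in task class instead of ConfigurableTask (illustrative only; FakeTask and adjust_nested are hypothetical names, not harness API):

from dataclasses import dataclass, field

@dataclass
class FakeTask:
    # Stand-in for ConfigurableTask; only the config dict this sketch needs.
    config: dict = field(default_factory=dict)

def adjust_nested(task_dict, num_fewshot=None):
    # Mirrors the recursion above: dict values are groups, everything else is a task.
    adjusted = {}
    for name, obj in task_dict.items():
        if isinstance(obj, dict):
            adjusted[name] = adjust_nested(obj, num_fewshot)
            continue
        default = obj.config.get("num_fewshot")
        if num_fewshot is not None and default != 0:
            obj.config["num_fewshot"] = num_fewshot  # explicit override wins
        elif default is None:
            obj.config["num_fewshot"] = 0            # no default configured: zero-shot
        adjusted[name] = obj
    return adjusted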
@@ -330,7 +339,10 @@ def evaluate(
padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
eval_tasks = get_task_list(task_dict)
# print("task_hierarchy")
# print(task_hierarchy)
# import sys; sys.exit()
if not log_samples:
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
@@ -484,72 +496,87 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
show_group_table = False
for group, group_info in reversed(task_hierarchy.items()):
task_list = group_info["tasks"]
if len(task_list) == 0:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
group_config = lm_eval.api.task.GroupConfig(
**group_info["config"] if "config" in group_info else {}
)
show_group_table = show_group_table | group_config["aggregate_metric"]
if group_config["aggregate_metric"] is False:
results[group][" "] = " "
continue
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
sizes,
group_config["weight_by_size"],
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
def process_group(results, task_dict, task_root=None, task_hierarchy=None, show_group_table=False):
results[group]["samples"] = sum(sizes)
if task_root is None:
task_root = {}
if task_hierarchy is None:
task_hierarchy = {}
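# Walk the task dict: leaf tasks are recorded under their parent group in
# task_hierarchy; group entries recurse first and then have their subtasks'
# metrics aggregated below.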
for group_or_task, group_or_task_info in task_dict.items():
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_hierarchy.setdefault(task_root, []).append(group_or_task)
else:
results, _task_hierarchy, show_group_table = process_group(results, group_or_task_info, group_or_task, task_hierarchy, show_group_table)
if task_root:
task_hierarchy.setdefault(task_root, []).extend(task_hierarchy.get(group_or_task, []))
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group = group_or_task.group
show_group_table = show_group_table | group_config["aggregate_metric"]
if group_config["aggregate_metric"] is False:
results[group][" "] = " "
continue
elif isinstance(group_or_task, str):
results[group_or_task][" "] = " "
continue
task_list = _task_hierarchy[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
sizes,
group_config["weight_by_size"],
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
return results, task_hierarchy, show_group_table
results, task_hierarchy, show_group_table = process_group(results, task_dict)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
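
The group-level scores written by process_group come from lm_eval.api.metrics.aggregate_subtask_metrics and pooled_sample_stderr. A rough, self-contained sketch of that pooling (standard weighted-mean math under an independence assumption; the harness's own formulas may differ in detail):

import math

def weighted_mean(metrics, sizes, weight_by_size=True):
    # Size-weighted mean of subtask scores, or a plain mean when weighting is off.
    if not weight_by_size:
        sizes = [1] * len(metrics)
    return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

def pooled_stderr(stderrs, sizes):
    # Stderr of the size-weighted mean, assuming independent subtask estimates:
    # Var(sum(w_i * x_i)) = sum(w_i**2 * se_i**2) with w_i = n_i / N.
    total = sum(sizes)
    return math.sqrt(sum((n / total) ** 2 * se ** 2 for se, n in zip(stderrs, sizes)))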
@@ -575,6 +602,7 @@ def evaluate(
task_list[0]
] # TODO: validate this
results_dict = {
"results": dict(results_agg.items()),
**(
......