"vscode:/vscode.git/clone" did not exist on "0bae0ff47d4bb225b6587726de8d444cfd2b275c"
Commit 62572f05 authored by lintangsutawika
Browse files

adjust group scoring using ConfigurableGroup

parent fe2a1472
......@@ -22,7 +22,7 @@ from lm_eval.evaluator_utils import (
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.tasks import ConfigurableGroup, ConfigurableTask, TaskManager, get_task_dict
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
......@@ -204,17 +204,24 @@ def simple_evaluate(
+ ".db",
)
if check_integrity:
run_task_tests(task_list=tasks)
if task_manager is None:
task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if isinstance(task_obj, lm_eval.api.task.ConfigurableTask) is False:
continue
def _adjust_config(task_dict):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj)}
}
else:
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
......@@ -246,9 +253,11 @@ def simple_evaluate(
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
task_obj.set_config(key="num_fewshot", value=0)
if check_integrity:
run_task_tests(task_list=tasks)
adjusted_task_dict[task_name] = task_obj
return adjusted_task_dict
task_dict = _adjust_config(task_dict)
results = evaluate(
lm=lm,
task_dict=task_dict,
......@@ -330,7 +339,10 @@ def evaluate(
padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
eval_tasks = get_task_list(task_dict)
# print("task_hierarchy")
# print(task_hierarchy)
# import sys; sys.exit()
if not log_samples:
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
......@@ -484,24 +496,36 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
show_group_table = False
for group, group_info in reversed(task_hierarchy.items()):
task_list = group_info["tasks"]
if len(task_list) == 0:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
def process_group(results, task_dict, task_root=None, task_hierarchy=None, show_group_table=False):
group_config = lm_eval.api.task.GroupConfig(
**group_info["config"] if "config" in group_info else {}
)
if task_root is None:
task_root = {}
if task_hierarchy is None:
task_hierarchy = {}
for group_or_task, group_or_task_info in task_dict.items():
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_hierarchy.setdefault(task_root, []).append(group_or_task)
else:
results, _task_hierarchy, show_group_table = process_group(results, group_or_task_info, group_or_task, task_hierarchy, show_group_table)
if task_root:
task_hierarchy.setdefault(task_root, []).extend(task_hierarchy.get(group_or_task, []))
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group = group_or_task.group
show_group_table = show_group_table | group_config["aggregate_metric"]
if group_config["aggregate_metric"] is False:
results[group][" "] = " "
continue
elif isinstance(group_or_task, str):
results[group_or_task][" "] = " "
continue
task_list = _task_hierarchy[group_or_task]
metric_list = list(
{
key
......@@ -550,6 +574,9 @@ def evaluate(
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group]["samples"] = sum(sizes)
return results, task_hierarchy, show_group_table
results, task_hierarchy, show_group_table = process_group(results, task_dict)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
......@@ -575,6 +602,7 @@ def evaluate(
task_list[0]
] # TODO: validate this
import sys; sys.exit()
results_dict = {
"results": dict(results_agg.items()),
**(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment