Commit ad70d206 authored by lintangsutawika

Update to work with the new group and task configuration.

parent c23c9305
@@ -60,6 +60,7 @@ class GroupConfig(dict):
aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False
metric_alias: Optional[str] = None
+ version: Optional[str] = 0
def __getitem__(self, item):
return getattr(self, item)
@@ -118,16 +119,20 @@ class ConfigurableGroup(abc.ABC):
def group_alias(self):
return self._config.group_alias
+ @property
+ def version(self):
+ return self._config.version
@property
def config(self):
return self._config.to_dict()
def __repr__(self):
return (
- f"ConfigurableGroup(group={self.group},"
- f"group_alias={self.group_alias})"
+ f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
)
@dataclass
class TaskConfig(dict):
# task naming/registry
...
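For reference, a minimal sketch of the pattern this hunk extends, assuming a stripped-down `GroupConfig`: the dataclass doubles as a dict (item access falls through to `getattr`), picks up a `version` field defaulting to 0, and `ConfigurableGroup` exposes it as a read-only property. Field names other than `version` are abbreviated here and purely illustrative.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GroupConfig(dict):
    # cut-down stand-in for the real GroupConfig; only a few fields shown
    group: Optional[str] = None
    aggregate_fn: Optional[str] = "mean"
    metric_alias: Optional[str] = None
    version: Optional[str] = 0  # new field: lets group results report a config version

    def __getitem__(self, item):
        # dict-style access is routed through the dataclass attributes
        return getattr(self, item)


class ConfigurableGroup:
    def __init__(self, config: GroupConfig):
        self._config = config

    @property
    def version(self):
        # pass-through so evaluate() can record group versions like task versions
        return self._config.version


cfg = GroupConfig(group="mmlu_stem", version=1)
group = ConfigurableGroup(cfg)
print(cfg["version"])  # 1, via __getitem__ -> getattr
print(group.version)   # 1, via the new property
```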
@@ -17,12 +17,16 @@ from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
get_task_list,
- prepare_print_tasks,
print_writeout,
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
- from lm_eval.tasks import ConfigurableGroup, ConfigurableTask, TaskManager, get_task_dict
+ from lm_eval.tasks import (
+ ConfigurableGroup,
+ ConfigurableTask,
+ TaskManager,
+ get_task_dict,
+ )
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
@@ -211,14 +215,14 @@ def simple_evaluate(
task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager)
def _adjust_config(task_dict):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
- **{task_name: _adjust_config(task_obj)}
+ **{task_name: _adjust_config(task_obj)},
}
else:
@@ -229,7 +233,6 @@ def simple_evaluate(
)
if predict_only:
- log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
@@ -250,7 +253,9 @@ def simple_evaluate(
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
- if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
+ if (
+ default_num_fewshot := task_obj.get_config("num_fewshot")
+ ) is None:
task_obj.set_config(key="num_fewshot", value=0)
adjusted_task_dict[task_name] = task_obj
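Context for the reformatted walrus check above: `_adjust_config` walks the (possibly nested) task dict and fills in few-shot settings, preferring a CLI-supplied `num_fewshot`, then the task's own default, then 0. A rough stand-alone sketch of that logic, with a fake task object standing in for `ConfigurableTask` (names are illustrative, not the harness API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTask:                       # stand-in for ConfigurableTask
    num_fewshot: Optional[int] = None


def adjust_config(task_dict: dict, num_fewshot: Optional[int] = None) -> dict:
    adjusted = {}
    for name, obj in task_dict.items():
        if isinstance(obj, dict):                   # a group: recurse into its subtasks
            adjusted[name] = adjust_config(obj, num_fewshot)
        else:
            if num_fewshot is not None:
                obj.num_fewshot = num_fewshot       # CLI value overrides everything
            elif obj.num_fewshot is None:           # no CLI value and no task default
                obj.num_fewshot = 0
            adjusted[name] = obj
    return adjusted


tasks = {"mmlu_stem": {"mmlu_anatomy": FakeTask(), "mmlu_virology": FakeTask(num_fewshot=5)}}
adjusted = adjust_config(tasks)
print(adjusted["mmlu_stem"]["mmlu_anatomy"].num_fewshot)   # 0
print(adjusted["mmlu_stem"]["mmlu_virology"].num_fewshot)  # 5
```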
@@ -266,7 +271,7 @@ def simple_evaluate(
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
write_out=write_out,
- log_samples=log_samples,
+ log_samples=True if predict_only else log_samples,
verbosity=verbosity,
)
@@ -340,9 +345,6 @@ def evaluate(
# get lists of group hierarchy and each type of request
eval_tasks = get_task_list(task_dict)
- # print("task_hierarchy")
- # print(task_hierarchy)
- # import sys; sys.exit()
if not log_samples:
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
@@ -496,8 +498,14 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
- def process_group(results, task_dict, task_root=None, task_hierarchy=None, show_group_table=False):
+ def process_group(
+ results,
+ task_dict,
+ task_root=None,
+ task_hierarchy=None,
+ show_group_table=False,
+ ):
if task_root is None:
task_root = {}
@@ -505,33 +513,42 @@ def evaluate(
task_hierarchy = {}
for group_or_task, group_or_task_info in task_dict.items():
- if isinstance(group_or_task_info, ConfigurableTask):
- if task_root:
- task_hierarchy.setdefault(task_root, []).append(group_or_task)
- else:
- results, _task_hierarchy, show_group_table = process_group(results, group_or_task_info, group_or_task, task_hierarchy, show_group_table)
- if task_root:
- task_hierarchy.setdefault(task_root, []).extend(task_hierarchy.get(group_or_task, []))
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
- group = group_or_task.group
+ group_or_task = group_or_task.group
- show_group_table = show_group_table | group_config["aggregate_metric"]
+ show_group_table = (
+ show_group_table | group_config["aggregate_metric"]
+ )
if group_config["aggregate_metric"] is False:
- results[group][" "] = " "
- continue
- elif isinstance(group_or_task, str):
results[group_or_task][" "] = " "
continue
+ if isinstance(group_or_task_info, ConfigurableTask):
+ if task_root:
+ task_hierarchy.setdefault(task_root, []).append(
+ group_or_task
+ )
+ else:
+ results, _task_hierarchy, show_group_table = process_group(
+ results,
+ group_or_task_info,
+ group_or_task,
+ task_hierarchy,
+ show_group_table,
+ )
+ if task_root:
+ task_hierarchy.setdefault(task_root, []).extend(
+ task_hierarchy.get(group_or_task, [])
+ )
task_list = _task_hierarchy[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
- if "_stderr" not in key and key not in ["alias", "samples"]
+ if "_stderr" not in key
+ and key not in ["alias", "samples"]
}
)
for metric in metric_list:
@@ -555,7 +572,7 @@ def evaluate(
]
# compute group's pooled metric and stderr
- results[group][
+ results[group_or_task][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
@@ -564,54 +581,69 @@ def evaluate(
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
- results[group][stderr] = "N/A"
+ results[group_or_task][stderr] = "N/A"
else:
- results[group][
+ results[group_or_task][
stderr
- ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+ ] = lm_eval.api.metrics.pooled_sample_stderr(
+ stderrs, sizes
+ )
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
- results[group]["samples"] = sum(sizes)
+ results[group_or_task]["samples"] = sum(sizes)
return results, task_hierarchy, show_group_table
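For readers of this hunk: `process_group` walks the task_dict recursively, pools each group's subtask metrics with `lm_eval.api.metrics.aggregate_subtask_metrics` / `pooled_sample_stderr`, and records which subtasks belong to which group. A minimal, self-contained sketch of that pattern, with plain dicts instead of `ConfigurableTask`/`ConfigurableGroup` and a sample-weighted mean standing in for the harness's aggregation helpers (names and data are illustrative, not the real API):

```python
from typing import Dict, List, Optional, Tuple


def aggregate_groups(
    results: Dict[str, Dict[str, float]],
    task_dict: dict,
    task_root: Optional[str] = None,
    hierarchy: Optional[Dict[str, List[str]]] = None,
) -> Tuple[dict, dict]:
    # Walk the nested task dict: dict values are groups, anything else is a leaf task.
    hierarchy = {} if hierarchy is None else hierarchy
    for name, info in task_dict.items():
        if isinstance(info, dict):
            # A group: recurse first so its subtasks are registered, then pool them.
            results, hierarchy = aggregate_groups(results, info, name, hierarchy)
            subtasks = hierarchy.get(name, [])
            sizes = [results[t]["samples"] for t in subtasks]
            accs = [results[t]["acc"] for t in subtasks]
            results[name] = {
                # sample-weighted mean, standing in for aggregate_subtask_metrics
                "acc": sum(a * s for a, s in zip(accs, sizes)) / sum(sizes),
                "samples": sum(sizes),
            }
            if task_root:
                hierarchy.setdefault(task_root, []).extend(subtasks)
        else:
            # A leaf task: just attach it to its parent group.
            if task_root:
                hierarchy.setdefault(task_root, []).append(name)
    return results, hierarchy


results = {
    "mmlu_anatomy": {"acc": 0.50, "samples": 100},
    "mmlu_virology": {"acc": 0.70, "samples": 300},
}
task_dict = {"mmlu_stem": {"mmlu_anatomy": None, "mmlu_virology": None}}
results, hierarchy = aggregate_groups(results, task_dict)
print(results["mmlu_stem"])  # {'acc': 0.65, 'samples': 400}
print(hierarchy)             # {'mmlu_stem': ['mmlu_anatomy', 'mmlu_virology']}
```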
- results, task_hierarchy, show_group_table = process_group(results, task_dict)
+ results, task_hierarchy, show_group_table = process_group(
+ results, task_dict
+ )
- print(task_hierarchy)
- import sys; sys.exit()
- results_agg = defaultdict(dict)
- groups_agg = defaultdict(dict)
- all_tasks_list = list(task_hierarchy.keys())
- while True:
- add_tasks_list = list(k for k in results_agg.keys())
- left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
- if len(left_tasks_list) == 0:
- break
- _task_hierarchy = {
- k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
- }
- _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
- results_agg = {**results_agg, **_results_agg}
- groups_agg = {**groups_agg, **_groups_agg}
- for group_name, group_info in task_hierarchy.items():
- task_list = group_info["tasks"]
- if task_list:
- num_fewshot[group_name] = num_fewshot[
- task_list[0]
- ] # TODO: validate this
+ def print_table(task_dict, results, task_depth=0):
+ task_agg = defaultdict(dict)
+ for task_or_group_name, task_or_group_obj in task_dict.items():
+ tab_string = " " * task_depth + "- " if task_depth > 0 else ""
+ if isinstance(task_or_group_name, ConfigurableGroup):
+ name = task_or_group_name.group
+ from_configurable_group = True
+ elif isinstance(task_or_group_name, str):
+ name = task_or_group_name
+ from_configurable_group = False
+ task_agg[name] = results[name].copy()
+ if from_configurable_group:
+ if task_or_group_name.group_alias is not None:
+ alias = task_or_group_name.group_alias
+ else:
+ alias = name
+ else:
+ if "alias" in task_agg[name]:
+ alias = task_agg[name]["alias"]
+ else:
+ alias = name
+ task_agg[name]["alias"] = tab_string + alias
+ if "samples" in task_agg[name]:
+ task_agg[name].pop("samples")
+ if isinstance(task_or_group_obj, dict):
+ task_depth += 1
+ task_agg = {
+ **task_agg,
+ **print_table(task_or_group_obj, results, task_depth),
+ }
+ task_depth -= 1
+ return task_agg
+ results_agg = print_table(task_dict, results)
+ import sys; sys.exit()
results_dict = {
"results": dict(results_agg.items()),
- **(
- {"groups": dict(groups_agg.items())}
- if (bool(groups_agg) & show_group_table)
- else {}
- ),
+ # **(
+ # {"groups": dict(groups_agg.items())}
+ # if (bool(groups_agg) & show_group_table)
+ # else {}
+ # ),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
...
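The old `prepare_print_tasks` loop is replaced by a recursive `print_table` that flattens the nested task_dict into display order, indents subtask aliases under their group, and drops the per-row sample counts. A rough sketch of that flattening, with plain strings standing in for `ConfigurableGroup` objects and hand-made results (illustrative only):

```python
from collections import defaultdict


def flatten_for_table(task_dict, results, depth=0):
    rows = defaultdict(dict)
    for name, subtree in task_dict.items():
        tab = " " * depth + "- " if depth > 0 else ""
        rows[name] = dict(results[name])            # copy so the alias can be edited
        rows[name]["alias"] = tab + rows[name].get("alias", name)
        rows[name].pop("samples", None)             # sample counts are not table columns
        if isinstance(subtree, dict):               # a group: recurse one level deeper
            rows.update(flatten_for_table(subtree, results, depth + 1))
    return rows


results = {
    "mmlu_stem": {"acc": 0.65},
    "mmlu_anatomy": {"acc": 0.50},
    "mmlu_virology": {"acc": 0.70},
}
task_dict = {"mmlu_stem": {"mmlu_anatomy": None, "mmlu_virology": None}}
for row in flatten_for_table(task_dict, results).values():
    print(row["alias"], row["acc"])
# mmlu_stem 0.65
#  - mmlu_anatomy 0.5
#  - mmlu_virology 0.7
```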
@@ -2,11 +2,11 @@ import collections
import math
import pathlib
import sys
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated
- from lm_eval.api.task import ConfigurableTask, ConfigurableGroup
class TaskOutput:
"""
@@ -121,7 +121,6 @@ class TaskOutput:
def get_task_list(task_dict: dict) -> List[TaskOutput]:
outputs = []
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
_outputs = get_task_list(task_obj)
outputs.extend(_outputs)
...
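`get_task_list` (context above) flattens the nested task dict into a flat list of per-task output records by recursing into group dicts. A self-contained sketch with a simplified stand-in for the harness's `TaskOutput`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SimpleTaskOutput:        # stand-in for TaskOutput
    task_name: str
    task: object


def get_task_list(task_dict: dict) -> List[SimpleTaskOutput]:
    outputs = []
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):      # a group: recurse into its subtasks
            outputs.extend(get_task_list(task_obj))
        else:                               # a leaf task: wrap it in an output record
            outputs.append(SimpleTaskOutput(task_name=task_name, task=task_obj))
    return outputs


tasks = {"mmlu_stem": {"mmlu_anatomy": "task-a", "mmlu_virology": "task-b"}}
print([o.task_name for o in get_task_list(tasks)])  # ['mmlu_anatomy', 'mmlu_virology']
```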
@@ -5,7 +5,8 @@ from functools import partial
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
- from lm_eval.api.task import ConfigurableTask, ConfigurableGroup, GroupConfig, Task
+ from lm_eval.api.task import ConfigurableGroup, ConfigurableTask, GroupConfig, Task
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
...
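`GROUP_ONLY_KEYS` is derived from the `GroupConfig` fields, giving the task index a list of keys that are group-level settings. A small sketch of one way such a key list can be used, with a cut-down `GroupConfig` and `asdict()` standing in for the real `to_dict()` (field names beyond those shown in this diff are illustrative):

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class GroupConfig:
    # abbreviated stand-in for the real GroupConfig
    group: Optional[str] = None
    group_alias: Optional[str] = None
    aggregate_metric: Optional[bool] = False
    aggregate_fn: Optional[str] = "mean"
    weight_by_size: Optional[bool] = False
    version: Optional[str] = 0


GROUP_ONLY_KEYS = list(asdict(GroupConfig()).keys())

yaml_config = {"group": "mmlu_stem", "aggregate_metric": True, "version": 1, "num_fewshot": 5}
group_keys = {k: v for k, v in yaml_config.items() if k in GROUP_ONLY_KEYS}
other_keys = {k: v for k, v in yaml_config.items() if k not in GROUP_ONLY_KEYS}
print(group_keys)  # {'group': 'mmlu_stem', 'aggregate_metric': True, 'version': 1}
print(other_keys)  # {'num_fewshot': 5}
```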
@@ -6,3 +6,4 @@ task:
- mmlu_humanities
aggregate_metric: True
weight_by_size: True
+ version: 1
@@ -16,3 +16,4 @@ task:
# - mmlu_world_religions
aggregate_metric: True
weight_by_size: True
+ version: 1
@@ -16,3 +16,4 @@ task:
# - mmlu_virology
aggregate_metric: True
weight_by_size: True
+ version: 1
@@ -15,3 +15,4 @@ task:
# - mmlu_us_foreign_policy
aggregate_metric: True
weight_by_size: True
+ version: 1
@@ -22,3 +22,4 @@ task:
# - mmlu_machine_learning
aggregate_metric: True
weight_by_size: True
+ version: 1
@@ -242,8 +242,11 @@ def make_table(result_dict, column: str = "results"):
values = []
for k, dic in result_dict[column].items():
- version = result_dict["versions"].get(k, "N/A")
+ version = result_dict["versions"].get(k, " N/A")
+ if k in result_dict["n-shot"]:
n = str(result_dict["n-shot"][k])
+ else:
+ n = " "
if "alias" in dic:
k = dic.pop("alias")
...
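`make_table` now guards both lookups: rows with no "n-shot" entry (typically group rows) get a blank cell, and a missing version falls back to the " N/A" placeholder. A hand-made example of those per-row lookups (the `result_dict` below is illustrative, not real harness output):

```python
result_dict = {
    "results": {
        "mmlu_stem": {"alias": "mmlu_stem", "acc,none": 0.65},
        "mmlu_anatomy": {"alias": " - mmlu_anatomy", "acc,none": 0.50},
    },
    "versions": {"mmlu_stem": 1},   # no entry for the subtask on purpose
    "n-shot": {"mmlu_anatomy": 5},  # group rows carry no n-shot entry
}

for k, dic in result_dict["results"].items():
    version = result_dict["versions"].get(k, " N/A")                      # missing version -> placeholder
    n = str(result_dict["n-shot"][k]) if k in result_dict["n-shot"] else " "
    print(f"{dic['alias']:<16} version={version} n-shot={n} acc={dic['acc,none']}")
```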