Commit 3b7e6cc6 authored by haileyschoelkopf

giving this a try

parent d13b1f56
......@@ -15,6 +15,7 @@ import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_group_results,
consolidate_results,
get_sample_size,
get_subtask_list,
......@@ -26,8 +27,6 @@ from lm_eval.evaluator_utils import (
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import (
ConfigurableGroup,
ConfigurableTask,
TaskManager,
get_task_dict,
)
......@@ -227,13 +226,15 @@ def simple_evaluate(
task_dict = get_task_dict(tasks, task_manager)
def _adjust_config(task_dict, predict_only):
# helper function to recursively apply config overrides to leaf subtasks, skipping the groups that contain them.
# (setting num_fewshot ; bypassing metric calculation ; setting the fewshot seed)
def _adjust_config(task_dict):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj, predict_only)},
**{task_name: _adjust_config(task_obj)},
}
else:
......@@ -278,7 +279,7 @@ def simple_evaluate(
return adjusted_task_dict
task_dict = _adjust_config(task_dict, predict_only)
task_dict = _adjust_config(task_dict)
if check_integrity:
run_task_tests(task_list=tasks)
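The `_adjust_config` helper touched in the hunk above walks the (possibly nested) task dictionary and applies per-run overrides only to leaf task objects, leaving group nodes alone. A minimal, self-contained sketch of that recursion pattern, with a toy stand-in for `ConfigurableTask` (names and fields here are illustrative, not the library's):

```python
from dataclasses import dataclass, replace

@dataclass
class ToyTask:
    # stand-in for a ConfigurableTask; fields are invented for illustration
    num_fewshot: int = 0

def adjust_config(task_dict: dict, num_fewshot: int) -> dict:
    """Recursively override leaf tasks; nested dicts are group nodes."""
    adjusted = {}
    for name, obj in task_dict.items():
        if isinstance(obj, dict):
            # group node: recurse into its subtasks, leave the group itself untouched
            adjusted[name] = adjust_config(obj, num_fewshot)
        else:
            # leaf task: apply the override
            adjusted[name] = replace(obj, num_fewshot=num_fewshot)
    return adjusted

task_dict = {"mmlu": {"mmlu_abstract_algebra": ToyTask(), "mmlu_anatomy": ToyTask()}}
print(adjust_config(task_dict, num_fewshot=5))
```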
......@@ -568,138 +569,7 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
def process_group(
results,
versions,
task_dict,
task_root=None,
show_group_table=False,
task_aggregation_list=None,
):
if task_root is None:
task_root = {}
if task_aggregation_list is None:
task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.task_id
else:
group_config = None
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_id
)
else:
(
results,
versions,
show_group_table,
_task_aggregation_list,
) = process_group(
results,
versions,
group_or_task_info,
group_or_task,
show_group_table,
task_aggregation_list,
)
if task_root:
task_aggregation_list.setdefault(task_root, []).extend(
task_aggregation_list.get(group_or_task, [])
)
if (group_config is None) or (
group_config["aggregate_metric_list"] is None
):
results[group_or_task][" "] = " "
continue
if "aggregate_metric_list" in group_config:
agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"]
)
task_list = _task_aggregation_list[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key
and key not in ["task", "alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
for metric_config in agg_metric_list:
for filter_name in metric_config["filter_list"]:
if metric != ",".join(
[metric_config["metric"], filter_name]
):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = lm_eval.api.metrics.aggregate_subtask_metrics
else:
raise ValueError(
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
)
results[group_or_task][metric] = aggregate_fn(
metrics,
sizes,
metric_config["weight_by_size"],
)
# TODO: calculate groups' metrics using arbitrary agg fns
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
# TODO: put in a warning, if we are using non-micro avg mean or another aggregation fn
results[group_or_task][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(
stderrs, sizes
)
results[group_or_task]["samples"] = sum(sizes)
group_metadata = group_config.get("metadata", None)
if group_metadata is not None:
versions[group_or_task] = group_metadata.get(
"version", None
)
# print(results)
return results, versions, show_group_table, task_aggregation_list
results, versions, show_group_table, *_ = process_group(
results, versions, show_group_table, *_ = consolidate_group_results(
results, versions, task_dict
)
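The large block removed above is the old in-place `process_group` logic; it is replaced by a call to `consolidate_group_results`, which this commit adds to `evaluator_utils.py` (shown further down in this diff). For the supported 'mean' aggregation, the group score reduces to a size-weighted average of subtask scores when `weight_by_size` is set. A quick worked example with invented numbers:

```python
# Two hypothetical subtasks of one group: per-task accuracies and sample counts are made up.
metrics = [0.60, 0.80]
sizes = [100, 400]

weighted = sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)  # 0.76
unweighted = sum(metrics) / len(metrics)                            # 0.70
print(weighted, unweighted)
```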
......
......@@ -4,8 +4,12 @@ import pathlib
import sys
from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics
from lm_eval.tasks import ConfigurableGroup, ConfigurableTask
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import ConfigurableGroup, ConfigurableTask
from lm_eval.utils import eval_logger, positional_deprecated
......@@ -103,7 +107,7 @@ class TaskOutput:
self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric?
if isinstance(bootstrap_iters, int):
stderr_fn = metrics.stderr_for_metric(
stderr_fn = stderr_for_metric(
metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"]
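`stderr_for_metric` wraps the aggregation function in a bootstrap standard-error estimator, which is why `bootstrap_iters` is capped at 100 above for the expensive corpus-level metrics (bleu, chrf, ter). A rough standalone sketch of what such an estimator does; this is not lm-eval's implementation, just the general idea:

```python
import random
import statistics

def bootstrap_stderr(agg_fn, items, iters=1000, seed=1234):
    """Resample `items` with replacement `iters` times, aggregate each resample,
    and take the standard deviation of those aggregates. Illustrative only."""
    rng = random.Random(seed)
    n = len(items)
    draws = [agg_fn([items[rng.randrange(n)] for _ in range(n)]) for _ in range(iters)]
    return statistics.stdev(draws)

# e.g. per-example 0/1 accuracies for a toy task, aggregated by the mean
print(bootstrap_stderr(lambda xs: sum(xs) / len(xs), [1, 0, 1, 1, 0, 1], iters=200))
```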
......@@ -128,19 +132,13 @@ class TaskOutput:
)
def get_task_list(task_dict: dict, task_root=None) -> List[TaskOutput]:
def get_task_list(task_dict: dict) -> List[TaskOutput]:
outputs = []
for task_name, task_obj in task_dict.items():
if isinstance(task_name, str):
prefix_name = task_name
else:
prefix_name = task_name.task_id
if isinstance(task_obj, dict):
_outputs = get_task_list(task_obj, task_root=prefix_name)
_outputs = get_task_list(task_obj)
outputs.extend(_outputs)
else:
task_obj.task_id = f"{task_root}:{task_obj.task_id}"
task_output = TaskOutput.from_taskdict(task_name, task_obj)
outputs.append(task_output)
......@@ -152,7 +150,7 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup):
# group_name = group_obj.group_name
group_name = group_obj.task_id
group_name = group_obj.group_name
else:
group_name = group_obj
if isinstance(task_obj, dict):
......@@ -172,10 +170,10 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
else:
if isinstance(task_obj, ConfigurableGroup):
# group_or_task_name = task_obj.group_name
group_or_task_name = task_obj.task_id
group_or_task_name = task_obj.group_name
elif isinstance(task_obj, ConfigurableTask):
# group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_id
group_or_task_name = task_obj.task_name
if task_root is None:
subtask_list.setdefault((group_or_task_name, depth), [])
......@@ -239,13 +237,13 @@ def prepare_print_tasks(
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup):
# string_name = task_or_group_name.group_name
name = task_or_group_name.task_id
name = task_or_group_name.group_name
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, ConfigurableTask):
# string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_id
name = task_or_group_obj.task_name
from_configurable_group = False
task_agg[name] = results[name].copy()
......@@ -325,29 +323,168 @@ def consolidate_results(
for task_output in eval_tasks:
# results[task_output.task_id]["task"] = task_output.task_name
if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_id]["alias"] = task_config["task_alias"]
results[task_output.task_name]["alias"] = task_config["task_alias"]
else:
results[task_output.task_id]["alias"] = task_output.task_name
results[task_output.task_name]["alias"] = task_output.task_name
if group_alias := task_output.group_alias:
if group_alias not in results and (group_name := task_output.group_name):
results[group_name]["alias"] = group_alias
num_fewshot[task_output.task_id] = task_output.n_shot
configs[task_output.task_id] = task_output.task_config
versions[task_output.task_id] = task_output.version
samples[task_output.task_id] = task_output.logged_samples
higher_is_better[task_output.task_id] = task_output.task.higher_is_better()
num_fewshot[task_output.task_name] = task_output.n_shot
configs[task_output.task_name] = task_output.task_config
versions[task_output.task_name] = task_output.version
samples[task_output.task_name] = task_output.logged_samples
higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
for (metric, filter_key), items in task_output.sample_metrics.items():
metric_key = f"{metric},{filter_key}"
results[task_output.task_id][metric_key] = task_output.agg_metrics[
results[task_output.task_name][metric_key] = task_output.agg_metrics[
metric_key
]
results[task_output.task_id]["samples"] = task_output.sample_len
results[task_output.task_id][
results[task_output.task_name]["samples"] = task_output.sample_len
results[task_output.task_name][
f"{metric}_stderr,{filter_key}"
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
return results, samples, configs, versions, num_fewshot, higher_is_better
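After `consolidate_results` runs, each per-task entry in `results` is keyed by `task_name` (rather than `task_id`, per this commit) and carries the alias, one value per `metric,filter` pair, the matching `_stderr` key, and the sample count. Roughly, with invented numbers:

```python
# Hypothetical shape of one consolidated entry (values invented for illustration):
results = {
    "arc_easy": {
        "alias": "arc_easy",
        "acc,none": 0.73,
        "acc_stderr,none": 0.009,
        "samples": 2376,
    }
}
```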
def consolidate_group_results(
results,
versions,
task_dict,
task_root=None,
show_group_table=False,
task_aggregation_list=None,
) -> Tuple[dict, dict, bool, Union[None, dict]]:
"""
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
- results: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
- versions: A defaultdict with task names (and, after this function is called, group names of
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
In the top-level invocation of this function, the returned task_aggregation_list is ignored.
"""
if task_root is None:
task_root = {}
if task_aggregation_list is None:
task_aggregation_list = {}
for group_or_task, group_or_task_info in task_dict.items():
# Convert to string
if isinstance(group_or_task, ConfigurableGroup):
group_config = group_or_task.config
group_or_task = group_or_task.group_name
else:
group_config = None
if isinstance(group_or_task_info, ConfigurableTask):
if task_root:
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_name
)
else:
(
results,
versions,
show_group_table,
_task_aggregation_list,
) = consolidate_group_results(
results,
versions,
group_or_task_info,
group_or_task,
show_group_table,
task_aggregation_list,
)
if task_root:
task_aggregation_list.setdefault(task_root, []).extend(
task_aggregation_list.get(group_or_task, [])
)
if (group_config is None) or (
group_config["aggregate_metric_list"] is None
):
results[group_or_task][" "] = " "
continue
if "aggregate_metric_list" in group_config:
agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"]
)
task_list = _task_aggregation_list[group_or_task]
metric_list = list(
{
key
for task in task_list
for key in results[task].keys()
if "_stderr" not in key and key not in ["task", "alias", "samples"]
}
)
for metric in metric_list:
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [
results[task][metric]
for task in task_list
if metric in results[task]
] # TODO: copy?
stderrs = [
results[task][stderr]
for task in task_list
if stderr in results[task]
]
sizes = [
results[task]["samples"]
for task in task_list
if metric in results[task]
]
for metric_config in agg_metric_list:
for filter_name in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter_name]):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = aggregate_subtask_metrics
else:
raise ValueError(
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
)
results[group_or_task][metric] = aggregate_fn(
metrics,
sizes,
metric_config["weight_by_size"],
)
# TODO: calculate groups' metrics using arbitrary agg fns
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
results[group_or_task][stderr] = pooled_sample_stderr(
stderrs, sizes
)
results[group_or_task]["samples"] = sum(sizes)
group_metadata = group_config.get("metadata", None)
if group_metadata is not None:
versions[group_or_task] = group_metadata.get("version", None)
# print(results)
return results, versions, show_group_table, task_aggregation_list
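Alongside the weighted mean, `consolidate_group_results` pools the subtasks' standard errors via `pooled_sample_stderr(stderrs, sizes)` and sums their sample counts. Below is a standalone sketch of a standard pooled-variance computation for combining standard errors of subtask means; the library's exact formula may differ, so treat this as an assumption:

```python
import math

def pooled_sample_stderr_sketch(stderrs, sizes):
    """Combine per-subtask stderrs of the mean into one group-level stderr
    using a standard pooled-variance formula. Illustrative only."""
    # recover each subtask's sample variance from its stderr of the mean
    variances = [(se ** 2) * n for se, n in zip(stderrs, sizes)]
    pooled_var = sum((n - 1) * v for v, n in zip(variances, sizes)) / (sum(sizes) - len(sizes))
    return math.sqrt(pooled_var / sum(sizes))

# invented subtask stats: stderrs of the mean and sample counts
print(pooled_sample_stderr_sketch([0.02, 0.01], [100, 400]))  # ~0.0089
```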
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
......
......@@ -6,6 +6,7 @@ from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.task import ConfigurableGroup, ConfigurableTask, GroupConfig, Task
from lm_eval.evaluator_utils import get_subtask_list
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
......@@ -154,8 +155,9 @@ class TaskManager:
else:
task_object = ConfigurableTask(config=config)
if task != task_object.task_id:
task_object.task_id = task
# if task != task_object.task_id:
# assert False
# task_object.task_id = task
return {task: task_object}
......@@ -364,7 +366,7 @@ class TaskManager:
if attr in config:
if attr == "group" and print_info:
self.logger.info(
"`group` and `group_alias` will no longer be used in the next release of lm-eval. "
"`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
"`tag` will be used to allow calling a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup, "
"which will be the official way to create groups with the addition of group-wide configurations."
......@@ -413,6 +415,33 @@ def get_task_name_from_object(task_object):
)
def _check_duplicates(task_dict: dict) -> List[str]:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list, collects
all leaf subtasks contained within, and errors if any such leaf subtask is
"oversubscribed" to several disjoint groups.
"""
subtask_names = []
for key, value in task_dict.items():
subtask_names.extend(value)
duplicate_tasks = {
task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
}
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [
group
for group in task_dict.keys()
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
]
if len(duplicate_tasks) > 0:
raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
)
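Concretely, `_check_duplicates` consumes a `{group: [subtask, ...]}` mapping like the one produced by `get_subtask_list` and raises if two requested groups claim the same leaf task. A toy restatement of the check (group and task names invented, keys shown as plain strings for readability):

```python
# Toy input mimicking get_subtask_list output: two groups both claim "task_b".
subtask_map = {
    "group_1": ["task_a", "task_b"],
    "group_2": ["task_b", "task_c"],
}

names = [t for tasks in subtask_map.values() for t in tasks]
duplicates = {t for t in names if names.count(t) > 1}
competing = [g for g, ts in subtask_map.items() if duplicates & set(ts)]
print(duplicates, competing)  # {'task_b'} ['group_1', 'group_2']
```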
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
......@@ -430,6 +459,7 @@ def get_task_dict(
:return
Dictionary of task objects
"""
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
......@@ -476,8 +506,16 @@ def get_task_dict(
):
raise ValueError
return {
final_task_dict = {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict
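The caller-facing entry point is unchanged; with the new check, a request such as the one below (task/group names are only examples, and this assumes an lm-eval install matching this commit) now fails loudly if two requested groups compete for the same subtask instead of silently merging them:

```python
# Example usage sketch; "mmlu" and "arc_easy" are example task/group names.
from lm_eval.tasks import TaskManager, get_task_dict

task_manager = TaskManager()
task_dict = get_task_dict(["mmlu", "arc_easy"], task_manager)
print(list(task_dict.keys()))
```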