Unverified commit 14221c84, authored by Zafir Stojanovski, committed by GitHub

`higher_is_better` tickers in output table (#1893)



* Higher is better tickers in output table

* add extra check that `higher_is_better` has not already been set to None

* Update lm_eval/evaluator.py

* fix up formatting I messed up

* add comment (and retrigger tests)

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
parent 8f716817
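For context: with this change, `make_table` gains an extra column after "Metric" carrying a direction ticker, "↑" when higher is better for that metric, "↓" when lower is better, and blank when the direction is unknown or inconsistent across a group's subtasks. A hypothetical rendering (task names and values invented for illustration):

```
|  Tasks  |Version|Filter|n-shot|    Metric     | | Value |   |Stderr|
|---------|------:|------|-----:|---------------|-|------:|---|-----:|
|hellaswag|      1|none  |     0|acc            |↑| 0.5723|±  |0.0049|
|wikitext |      2|none  |     0|word_perplexity|↓|10.4217|   |      |
```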
@@ -503,9 +503,14 @@ def evaluate(
     # aggregate results ; run bootstrap CIs
     for task_output in eval_tasks:
         task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
-    results, samples, configs, versions, num_fewshot = consolidate_results(
-        eval_tasks
-    )
+    (
+        results,
+        samples,
+        configs,
+        versions,
+        num_fewshot,
+        higher_is_better,
+    ) = consolidate_results(eval_tasks)
 
     ### Calculate group metrics ###
     if bool(results):
@@ -516,6 +521,27 @@ def evaluate(
                 # or `task_name: []`.
                 # we only want to operate on groups here.
                 continue
+
+            # collect all higher_is_better values for metrics
+            # in the group's subtasks.
+            # TODO: clean this up ; unify with the below metric_list loop?
+            _higher_is_better = {}
+            for task in task_list:
+                for m, h in higher_is_better[task].items():
+                    if m not in _higher_is_better.keys():
+                        _higher_is_better[m] = h
+                    if (
+                        m in _higher_is_better
+                        and _higher_is_better[m] is not None
+                        and _higher_is_better[m] != h
+                    ):
+                        eval_logger.warning(
+                            f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
+                        )
+                        _higher_is_better[m] = None
+            higher_is_better[group] = _higher_is_better
+
+            # collect all metric keys used by a subtask in the group.
             metric_list = list(
                 {
                     key
@@ -591,6 +617,7 @@ def evaluate(
         "configs": dict(sorted(configs.items())),
         "versions": dict(sorted(versions.items())),
         "n-shot": dict(sorted(num_fewshot.items())),
+        "higher_is_better": dict(sorted(higher_is_better.items())),
         "n-samples": {
             task_output.task_name: {
                 "original": len(task_output.task.eval_docs),
...
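The group loop above merges the per-subtask direction flags and demotes a metric to `None` when subtasks disagree. A minimal standalone sketch of that pattern (function and variable names here are illustrative, not the harness's API):

```python
from typing import Dict, List, Optional

def merge_higher_is_better(
    per_task: List[Dict[str, bool]],
) -> Dict[str, Optional[bool]]:
    """Merge {metric: higher_is_better} dicts; conflicting directions become None."""
    merged: Dict[str, Optional[bool]] = {}
    for task_flags in per_task:
        for metric, hib in task_flags.items():
            if metric not in merged:
                merged[metric] = hib
            elif merged[metric] is not None and merged[metric] != hib:
                merged[metric] = None  # subtasks disagree; direction is undefined
    return merged

# Two subtasks disagreeing on "acc" yields None for the group:
assert merge_higher_is_better([{"acc": True}, {"acc": False}]) == {"acc": None}
```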
@@ -253,6 +253,9 @@ def consolidate_results(
     configs = collections.defaultdict(dict)
     # Tracks each task's version.
     versions = collections.defaultdict(dict)
+    # Track `higher_is_better` for each metric
+    higher_is_better = collections.defaultdict(dict)
+
     for task_output in eval_tasks:
         if "task_alias" in (task_config := task_output.task_config):
             results[task_output.task_name]["alias"] = task_config["task_alias"]
@@ -263,6 +266,7 @@ def consolidate_results(
         configs[task_output.task_name] = task_output.task_config
         versions[task_output.task_name] = task_output.version
         samples[task_output.task_name] = task_output.logged_samples
+        higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
         for (metric, filter_key), items in task_output.sample_metrics.items():
             metric_key = f"{metric},{filter_key}"
             results[task_output.task_name][metric_key] = task_output.agg_metrics[
@@ -272,7 +276,7 @@ def consolidate_results(
             results[task_output.task_name][
                 f"{metric}_stderr,{filter_key}"
             ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
-    return results, samples, configs, versions, num_fewshot
+    return results, samples, configs, versions, num_fewshot, higher_is_better
 
 
 @positional_deprecated
...
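`consolidate_results` now also returns a per-task mapping built from each task's `higher_is_better()` method, which (as the loop above implies) yields a `{metric_name: bool}` dict. A sketch of the consolidated shape, with illustrative task and metric names:

```python
# What `higher_is_better` looks like after consolidation (illustrative values):
higher_is_better = {
    "hellaswag": {"acc": True, "acc_norm": True},
    "wikitext": {"word_perplexity": False, "bits_per_byte": False},
}
# evaluate() then serializes it into the results dict, sorted by task name:
results_entry = dict(sorted(higher_is_better.items()))
```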
@@ -26,6 +26,11 @@ eval_logger = logging.getLogger("lm-eval")
 
 SPACING = " " * 47
 
+HIGHER_IS_BETTER_SYMBOLS = {
+    True: "↑",
+    False: "↓",
+}
+
 
 def hash_string(string: str) -> str:
     return hashlib.sha256(string.encode("utf-8")).hexdigest()
@@ -257,6 +262,7 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
         "Filter",
         "n-shot",
         "Metric",
+        "",
         "Value",
         "",
         "Stderr",
@@ -277,6 +283,7 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
         dic = result_dict[column][k]
         version = result_dict["versions"].get(k, "N/A")
         n = str(result_dict["n-shot"][k])
+        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
 
         if "alias" in dic:
             k = dic.pop("alias")
@@ -286,13 +293,15 @@ def make_table(result_dict, column: str = "results", sort_results: bool = True):
             if m.endswith("_stderr"):
                 continue
 
+            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
+
             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
                 if se != "N/A":
                     se = "%.4f" % se
-                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
             else:
-                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
+                values.append([k, version, f, n, m, hib, "%.4f" % v, "", ""])
             k = ""
             version = ""
     md_writer.value_matrix = values
...
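Note how the two-step lookup in `make_table` degrades gracefully: the inner `.get(m)` returns `None` for a metric with no recorded direction (including group-level conflicts), and the outer `.get(..., "")` maps both `None` and missing keys to an empty cell. A small sketch of that behavior:

```python
HIGHER_IS_BETTER_SYMBOLS = {True: "↑", False: "↓"}

def ticker(higher_is_better: dict, metric: str) -> str:
    # None (conflicting direction) and absent metrics both fall through to ""
    return HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(metric), "")

assert ticker({"acc": True}, "acc") == "↑"
assert ticker({"word_perplexity": False}, "word_perplexity") == "↓"
assert ticker({"acc": None}, "acc") == ""   # inconsistent direction in a group
assert ticker({}, "acc") == ""              # no higher_is_better info at all
```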