Commit 2f3d4272 authored by lintangsutawika
Browse files

use task id to differentiate tasks

parent 5a3a9573
...@@ -5,7 +5,7 @@ import sys ...@@ -5,7 +5,7 @@ import sys
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lm_eval.api import metrics from lm_eval.api import metrics
from lm_eval.tasks import ConfigurableGroup from lm_eval.tasks import ConfigurableGroup, ConfigurableTask
from lm_eval.utils import eval_logger, positional_deprecated from lm_eval.utils import eval_logger, positional_deprecated
...@@ -40,6 +40,7 @@ class TaskOutput: ...@@ -40,6 +40,7 @@ class TaskOutput:
self, self,
task=None, task=None,
task_name=None, task_name=None,
task_id=None,
task_config=None, task_config=None,
version=None, version=None,
group_name=None, group_name=None,
...@@ -51,6 +52,7 @@ class TaskOutput: ...@@ -51,6 +52,7 @@ class TaskOutput:
self.task = task self.task = task
self.task_config = task_config self.task_config = task_config
self.task_name = task_name self.task_name = task_name
self.task_id = task_id
self.group_name = group_name self.group_name = group_name
self.version = version self.version = version
self.n_shot = n_shot self.n_shot = n_shot
...@@ -76,6 +78,7 @@ class TaskOutput: ...@@ -76,6 +78,7 @@ class TaskOutput:
task=task, task_name=task_name, is_group=is_group, group_name=group_name task=task, task_name=task_name, is_group=is_group, group_name=group_name
) )
version = task.VERSION version = task.VERSION
task_id = task.task_id
task_config = dict(task.dump_config()) task_config = dict(task.dump_config())
if (n_shot := task_config.get("num_fewshot")) == 0: if (n_shot := task_config.get("num_fewshot")) == 0:
n_shot = task_config.get("metadata", {}).get("num_fewshot", 0) n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
...@@ -84,6 +87,7 @@ class TaskOutput: ...@@ -84,6 +87,7 @@ class TaskOutput:
return cls( return cls(
task=task, task=task,
task_name=task_name, task_name=task_name,
task_id=task_id,
task_config=task_config, task_config=task_config,
group_name=group_name, group_name=group_name,
version=version, version=version,
...@@ -113,9 +117,10 @@ class TaskOutput: ...@@ -113,9 +117,10 @@ class TaskOutput:
return ( return (
f"TaskOutput(task_name={self.task_name}, " f"TaskOutput(task_name={self.task_name}, "
f"group_name={self.group_name}, " f"group_name={self.group_name}, "
f"version={self.version}," f"version={self.version}, "
f"n_shot={self.n_shot}" f"n_shot={self.n_shot}, "
f"task_alias={self.task_alias}, group_alias={self.group_alias})" f"task_alias={self.task_alias}, "
f"group_alias={self.group_alias})"
) )
...@@ -176,10 +181,13 @@ def prepare_print_tasks( ...@@ -176,10 +181,13 @@ def prepare_print_tasks(
for task_or_group_name, task_or_group_obj in task_dict.items(): for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else "" tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup): if isinstance(task_or_group_name, ConfigurableGroup):
name = task_or_group_name.group # name = task_or_group_name.group
name = task_or_group_name.task_id
from_configurable_group = True from_configurable_group = True
elif isinstance(task_or_group_name, str): elif isinstance(task_or_group_name, str):
name = task_or_group_name name = task_or_group_name
if isinstance(task_or_group_obj, ConfigurableTask):
name = task_or_group_obj.task_id
from_configurable_group = False from_configurable_group = False
task_agg[name] = results[name].copy() task_agg[name] = results[name].copy()
...@@ -187,7 +195,7 @@ def prepare_print_tasks( ...@@ -187,7 +195,7 @@ def prepare_print_tasks(
if task_or_group_name.group_alias is not None: if task_or_group_name.group_alias is not None:
alias = task_or_group_name.group_alias alias = task_or_group_name.group_alias
else: else:
alias = name alias = task_or_group_name.group
else: else:
if "alias" in task_agg[name]: if "alias" in task_agg[name]:
alias = task_agg[name]["alias"] alias = task_agg[name]["alias"]
...@@ -255,21 +263,23 @@ def consolidate_results( ...@@ -255,21 +263,23 @@ def consolidate_results(
versions = collections.defaultdict(dict) versions = collections.defaultdict(dict)
for task_output in eval_tasks: for task_output in eval_tasks:
if "task_alias" in (task_config := task_output.task_config): if "task_alias" in (task_config := task_output.task_config):
results[task_output.task_name]["alias"] = task_config["task_alias"] results[task_output.task_id]["alias"] = task_config["task_alias"]
else:
results[task_output.task_id]["alias"] = task_output.task_name
if group_alias := task_output.group_alias: if group_alias := task_output.group_alias:
if group_alias not in results and (group_name := task_output.group_name): if group_alias not in results and (group_name := task_output.group_name):
results[group_name]["alias"] = group_alias results[group_name]["alias"] = group_alias
num_fewshot[task_output.task_name] = task_output.n_shot num_fewshot[task_output.task_id] = task_output.n_shot
configs[task_output.task_name] = task_output.task_config configs[task_output.task_id] = task_output.task_config
versions[task_output.task_name] = task_output.version versions[task_output.task_id] = task_output.version
samples[task_output.task_name] = task_output.logged_samples samples[task_output.task_id] = task_output.logged_samples
for (metric, filter_key), items in task_output.sample_metrics.items(): for (metric, filter_key), items in task_output.sample_metrics.items():
metric_key = f"{metric},{filter_key}" metric_key = f"{metric},{filter_key}"
results[task_output.task_name][metric_key] = task_output.agg_metrics[ results[task_output.task_id][metric_key] = task_output.agg_metrics[
metric_key metric_key
] ]
results[task_output.task_name]["samples"] = task_output.sample_len results[task_output.task_id]["samples"] = task_output.sample_len
results[task_output.task_name][ results[task_output.task_id][
f"{metric}_stderr,{filter_key}" f"{metric}_stderr,{filter_key}"
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
return results, samples, configs, versions, num_fewshot return results, samples, configs, versions, num_fewshot
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment