Commit 911cae22 authored by Baber

TODO!

parent 69e95b87
 import abc
 import ast
+import collections
 import logging
 import random
 import re
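The collections import is presumably what this hunk adds: the reworked calculate_metrics below accumulates results per filter in a collections.defaultdict(list). A minimal standalone sketch of that accumulation pattern, with toy filter names and values rather than the harness's real outputs:

import collections

# group values under a key without pre-creating each list
per_filter = collections.defaultdict(list)
samples = [("none", {"acc": 1.0}), ("none", {"acc": 0.0}), ("strict-match", {"acc": 1.0})]
for filter_key, result in samples:
    per_filter[filter_key].append(result)

print(dict(per_filter))
# {'none': [{'acc': 1.0}, {'acc': 0.0}], 'strict-match': [{'acc': 1.0}]}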
@@ -1768,8 +1769,14 @@ class ConfigurableTask(Task):
         )

     def calculate_metrics(
-        self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
-    ) -> list[list[dict]]:
+        self,
+        instances_by_doc_id,
+        filter_keys=None,
+        samples=None,
+        rank=1,
+        limit=None,
+        world_size=1,
+    ) -> dict[str, list[dict]]:
         """Calculate metrics for all datapoints in the task.

         Args:
@@ -1783,28 +1790,35 @@ class ConfigurableTask(Task):
         Returns:
             list: A list of metrics calculated for each document.
         """
-        all_metrics = []
+        if filter_keys is None:
+            filter_keys = [x.name for x in self._filters]
+        if isinstance(filter_keys, str):
+            filter_keys = [filter_keys]
+        all_metrics = collections.defaultdict(list)
         # indices = samples.get(self.config.task, None) if samples is not None else None
-        doc_iterator = self.doc_iterator(
-            rank=rank,
-            limit=limit,
-            world_size=world_size,
-            # samples=indices,
-        )
-        for doc_id, doc in doc_iterator:
-            # doc_id_true = indices[doc_id] if indices else doc_id
-            requests = instances_by_doc_id[doc_id]
-            metrics: list[list[dict]] = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in req.filtered_resps[filter_key]
-            ]
-            all_metrics.append(metrics)
+        for filter_key in filter_keys:
+            doc_iterator = self.doc_iterator(
+                rank=rank,
+                limit=limit,
+                world_size=world_size,
+                # samples=indices,
+            )
+            for doc_id, doc in doc_iterator:
+                # doc_id_true = indices[doc_id] if indices else doc_id
+                requests = instances_by_doc_id[doc_id]
+                metrics = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in (
+                        req.filtered_resps[filter_key]
+                        if isinstance(req.filtered_resps[filter_key], list)
+                        else [req.filtered_resps[filter_key]]
+                    )
+                ]
+                # TODO: This turns metrics into a list of lists of dicts rather than flat list.
+                all_metrics[filter_key].append(metrics)
         return all_metrics
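As the TODO above notes, calculate_metrics now returns a dict keyed by filter name whose values are per-document lists of per-response result dicts, not a flat list. A short sketch of flattening that shape into (metric, filter) buckets, using made-up values rather than real harness output:

import collections

# shape returned above: {filter_key: [per-doc [per-response {metric_name: value}]]}
all_metrics = {
    "none": [
        [{"acc": 1.0}, {"acc": 0.0}],  # document 0 produced two responses
        [{"acc": 1.0}],                # document 1 produced one response
    ],
}

flat = collections.defaultdict(list)
for filter_key, per_doc in all_metrics.items():
    for doc_metrics in per_doc:
        for sample_metric in doc_metrics:
            for metric_name, value in sample_metric.items():
                flat[(metric_name, filter_key)].append(value)

print(dict(flat))
# {('acc', 'none'): [1.0, 0.0, 1.0]}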
@@ -38,7 +38,7 @@ from lm_eval.utils import (
 if TYPE_CHECKING:
     from lm_eval.api.model import LM
-    from lm_eval.api.task import Task
+    from lm_eval.api.task import ConfigurableTask, Task

 eval_logger = logging.getLogger(__name__)
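For context on the import change: typing.TYPE_CHECKING is False at runtime, so the ConfigurableTask import is only evaluated by static type checkers and cannot introduce a runtime import cycle, while local variable annotations like the one added in the next hunk are never evaluated at runtime. A self-contained sketch of the pattern, with hypothetical names (mypkg.tasks and Job are not from the harness):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen only by type checkers; never imported at runtime
    from mypkg.tasks import Job  # hypothetical module


def handle(item) -> None:
    # local annotations are not evaluated at runtime, so naming Job here
    # is safe even though it was never imported
    job: Job = item
    print(job)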
@@ -585,7 +585,7 @@ def evaluate(
     ### Postprocess outputs ###
     # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
     for task_output, limit in zip(eval_tasks, limits):
-        task = task_output.task
+        task: ConfigurableTask = task_output.task
         task.apply_filters()

         ### Collect values of metrics on all datapoints ###
@@ -599,17 +599,25 @@ def evaluate(
         for instances in instances_by_doc_id.values():
             instances.sort(key=lambda x: x.idx)

         # iterate over different filters used
-        if hasattr(task, "calculate_metrics"):
-            metrics = task.calculate_metrics(
-                instances_by_doc_id=instances_by_doc_id,
-                samples=samples,
-                rank=RANK,
-                limit=limit,
-                world_size=WORLD_SIZE,
-            )
+        for filter_key in task.instances[0].filtered_resps.keys():
+            if hasattr(task, "calculate_metrics"):
+                # Use the new method if it exists (ConfigurableTask)
+                metrics = task.calculate_metrics(
+                    instances_by_doc_id=instances_by_doc_id,
+                    filter_keys=filter_key,
+                    samples=samples,
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                )
+                # Add sample logging here too - similar to what's done in the else branch
                 if log_samples:
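The hunk above dispatches on capability rather than concrete type: when the task object exposes calculate_metrics, the new per-filter path runs; otherwise control falls through to the original per-document loop in the else branch continued below. A stripped-down sketch of that hasattr dispatch, using toy stand-in classes rather than the harness's Task types:

class LegacyTask:
    pass


class NewTask:
    def calculate_metrics(self):
        # same nested shape as above: filter -> docs -> response dicts
        return {"none": [[{"acc": 1.0}]]}


def collect_metrics(task):
    # capability check: prefer the richer API when the task provides it
    if hasattr(task, "calculate_metrics"):
        return task.calculate_metrics()
    # caller handles legacy tasks with the original per-document logic
    return {}


print(collect_metrics(NewTask()))     # {'none': [[{'acc': 1.0}]]}
print(collect_metrics(LegacyTask()))  # {}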
@@ -663,9 +671,19 @@ def evaluate(
                         task_output.logged_samples.append(example)

                 # Process all metrics returned from calculate_metrics
-            for x in metrics:
-                for metric, value in x.items():
-                    task_output.sample_metrics[(metric, filter_key)].append(value)
+                # metrics maps each filter key to one list per document of
+                # per-response result dicts, e.g. {"none": [[{"acc": 1.0}], [{"acc": 0.0}]]},
+                # so unpack filter key -> document -> response dict -> (metric name, value)
+                for filter_key in metrics:
+                    for doc_metrics in metrics[filter_key]:
+                        for sample_metric in doc_metrics:
+                            for metric_key, value in sample_metric.items():
+                                task_output.sample_metrics[(metric_key, filter_key)].append(
+                                    value
+                                )
             else:
                 # Fall back to the original approach for non-ConfigurableTask instances
                 indices = (