"vscode:/vscode.git/clone" did not exist on "c3fe0550a70a807ffef5c0c49573624abd52d813"
Commit 911cae22 authored by Baber

TODO!

parent 69e95b87
 import abc
 import ast
+import collections
 import logging
 import random
 import re
@@ -1768,8 +1769,14 @@ class ConfigurableTask(Task):
         )

     def calculate_metrics(
-        self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
-    ) -> list[list[dict]]:
+        self,
+        instances_by_doc_id,
+        filter_keys=None,
+        samples=None,
+        rank=1,
+        limit=None,
+        world_size=1,
+    ) -> dict[str, list[dict]]:
         """Calculate metrics for all datapoints in the task.

         Args:
@@ -1783,28 +1790,35 @@ class ConfigurableTask(Task):
         Returns:
             list: A list of metrics calculated for each document.
         """
-        all_metrics = []
+        if filter_keys is None:
+            filter_keys = [x.name for x in self._filters]
+        if isinstance(filter_keys, str):
+            filter_keys = [filter_keys]
+        all_metrics = collections.defaultdict(list)
         # indices = samples.get(self.config.task, None) if samples is not None else None
-        doc_iterator = self.doc_iterator(
-            rank=rank,
-            limit=limit,
-            world_size=world_size,
-            # samples=indices,
-        )
-        for doc_id, doc in doc_iterator:
-            # doc_id_true = indices[doc_id] if indices else doc_id
-            requests = instances_by_doc_id[doc_id]
-            metrics: list[list[dict]] = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in req.filtered_resps[filter_key]
-            ]
-            # TODO: This turns metrics into a list of lists of dicts rather than flat list.
-            all_metrics.append(metrics)
+        for filter_key in filter_keys:
+            doc_iterator = self.doc_iterator(
+                rank=rank,
+                limit=limit,
+                world_size=world_size,
+                # samples=indices,
+            )
+            for doc_id, doc in doc_iterator:
+                # doc_id_true = indices[doc_id] if indices else doc_id
+                requests = instances_by_doc_id[doc_id]
+                metrics = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in (
+                        req.filtered_resps[filter_key]
+                        if isinstance(req.filtered_resps[filter_key], list)
+                        else [req.filtered_resps[filter_key]]
+                    )
+                ]
+                all_metrics[filter_key].append(metrics)
         return all_metrics
...
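For orientation, the loop above builds a mapping from each filter name to a per-document list whose entries are themselves lists of per-response metric dicts, i.e. dict[str, list[list[dict]]] rather than the dict[str, list[dict]] in the new annotation (the docstring still advertises a plain list). A minimal, self-contained illustration of that shape; the filter names and metric values are invented, only the nesting mirrors the code:

# Hypothetical value illustrating what the loop in calculate_metrics builds
# (filter names and values are made up; only the structure matters here):
all_metrics = {
    "none": [                 # filter key -> one entry per document
        [{"acc": 1.0}],       # doc 0: one metric dict per response
        [{"acc": 0.0}],       # doc 1
    ],
    "strict-match": [
        [{"acc": 1.0}],
        [{"acc": 1.0}],
    ],
}

# i.e. dict[str, list[list[dict[str, float]]]] -- one nesting level deeper
# than the dict[str, list[dict]] annotation in the new signature.
assert isinstance(all_metrics["none"][0], list)
assert isinstance(all_metrics["none"][0][0], dict)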
@@ -38,7 +38,7 @@ from lm_eval.utils import (

 if TYPE_CHECKING:
     from lm_eval.api.model import LM
-    from lm_eval.api.task import Task
+    from lm_eval.api.task import ConfigurableTask, Task

 eval_logger = logging.getLogger(__name__)
@@ -585,7 +585,7 @@ def evaluate(
     ### Postprocess outputs ###
     # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
     for task_output, limit in zip(eval_tasks, limits):
-        task = task_output.task
+        task: ConfigurableTask = task_output.task
         task.apply_filters()

         ### Collect values of metrics on all datapoints ###
@@ -599,17 +599,25 @@ def evaluate(
         for instances in instances_by_doc_id.values():
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
+        if hasattr(task, "calculate_metrics"):
+            metrics = task.calculate_metrics(
+                instances_by_doc_id=instances_by_doc_id,
+                samples=samples,
+                rank=RANK,
+                limit=limit,
+                world_size=WORLD_SIZE,
+            )
         for filter_key in task.instances[0].filtered_resps.keys():
             if hasattr(task, "calculate_metrics"):
                 # Use the new method if it exists (ConfigurableTask)
-                metrics = task.calculate_metrics(
-                    instances_by_doc_id=instances_by_doc_id,
-                    filter_key=filter_key,
-                    samples=samples,
-                    rank=RANK,
-                    limit=limit,
-                    world_size=WORLD_SIZE,
-                )
+                # metrics = task.calculate_metrics(
+                #     instances_by_doc_id=instances_by_doc_id,
+                #     filter_keys=filter_key,
+                #     samples=samples,
+                #     rank=RANK,
+                #     limit=limit,
+                #     world_size=WORLD_SIZE,
+                # )
                 # Add sample logging here too - similar to what's done in the else branch
                 if log_samples:
@@ -663,9 +671,19 @@ def evaluate(
                         task_output.logged_samples.append(example)
                 # Process all metrics returned from calculate_metrics
-                for x in metrics:
-                    for metric, value in x.items():
-                        task_output.sample_metrics[(metric, filter_key)].append(value)
+                for filter_key in metrics:
+                    for sample_metric in metrics[filter_key]:
+                        for metric_key, value in sample_metric:
+                            task_output.sample_metrics[(metric_key, filter_key)].append(
+                                value
+                            )
+                # metrics is a list of dictionaries, each containing metric names and their values
+                # e.g., [{"accuracy": 0.9}, {"f1": 0.8}]
+                # We need to iterate through each dictionary and extract the metric names and values
+                # for x in metrics:
+                #     for metric, value in x.items():
+                #         task_output.sample_metrics[(metric, filter_key)].append(value)
             else:
                 # Fall back to the original approach for non-ConfigurableTask instances
                 indices = (
...
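Consistent with the TODO! commit message, two details of the hunk above look unfinished: the enclosing for filter_key in task.instances[0].filtered_resps.keys() loop now re-processes the entire metrics dict on every pass, and for metric_key, value in sample_metric unpacks the elements of a list of dicts rather than dict items. A hedged sketch of how the accumulation could be written against the nested shape shown earlier, as a standalone helper rather than the author's code (the name accumulate_sample_metrics is invented):

from collections import defaultdict
from typing import Any


def accumulate_sample_metrics(
    metrics: dict[str, list[list[dict[str, Any]]]],
) -> dict[tuple[str, str], list[Any]]:
    """Flatten {filter_key: [per-doc [per-response {metric: value}]]} into the
    (metric_name, filter_key) -> [values] layout used by task_output.sample_metrics."""
    sample_metrics: dict[tuple[str, str], list[Any]] = defaultdict(list)
    for filter_key, per_doc_metrics in metrics.items():
        for per_response_metrics in per_doc_metrics:    # one list per document
            for metric_dict in per_response_metrics:    # one dict per response
                for metric_key, value in metric_dict.items():
                    sample_metrics[(metric_key, filter_key)].append(value)
    return sample_metrics


# Possible use inside evaluate(), replacing the triple loop above:
# for key, values in accumulate_sample_metrics(metrics).items():
#     task_output.sample_metrics[key].extend(values)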