Commit 57b91fdb authored by Baber

refactor: enhance metric handling and aggregation logic

parent 911cae22
......@@ -19,6 +19,9 @@ class GenerateInput:
else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
)
def __getitem__(self, item: int):
return [self.prompt, self.gen_kwargs][item]
@dataclass
class GenerateOutput:
......@@ -54,3 +57,40 @@ class LoglikelihoodOutput:
def __iter__(self):
return iter((self.loglikelihood, self.is_greedy))
@dataclass
class MetricResult:
"""
Outputs for the metric function.
"""
doc_id: str | int | None
scores: list[dict[str, float]] | None
filter_key: str | None = None
metric_name: str | None = None
metadata: Optional[dict] = None
def __iter__(self):
if self.scores is None:
return iter([])
# Group values by metric key
grouped = {}
for score_dict in self.scores:
for key, value in score_dict.items():
if key not in grouped:
grouped[key] = []
grouped[key].append(value)
# Return iterator of (key, list[values]) pairs
return iter(grouped.items())
def get_metric_results(self, metric_key: str) -> list[float]:
if self.scores is None:
return []
return [
score_dict[metric_key]
for score_dict in self.scores
if metric_key in score_dict
]
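A minimal usage sketch (not part of the commit; values are illustrative) of how MetricResult behaves with repeated scores: iterating groups the values by metric key, and get_metric_results pulls out a single metric's values:

result = MetricResult(
    doc_id=0,
    scores=[{"exact_match": 0.0}, {"exact_match": 1.0}, {"exact_match": 1.0}],
    filter_key="strict-match",
)
dict(result)                              # {"exact_match": [0.0, 1.0, 1.0]}
result.get_metric_results("exact_match")  # [0.0, 1.0, 1.0]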
......@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
get_metric_aggregation,
is_higher_better,
)
from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
......@@ -98,6 +98,7 @@ class TaskConfig(dict):
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
repeat_agg: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -1818,7 +1819,9 @@ class ConfigurableTask(Task):
)
]
all_metrics[filter_key].append(metrics)
all_metrics[filter_key].append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
)
return all_metrics
......
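For reference, a rough sketch (not from the diff; it mirrors the comments in the evaluate() hunk below) of the structure calculate_metrics now returns: one MetricResult per document under each filter key, with one score dict per repeat:

all_metrics = {
    "strict-match": [
        MetricResult(doc_id=0, scores=[{"exact_match": 0.0}, {"exact_match": 1.0}, {"exact_match": 1.0}], filter_key="strict-match"),
        MetricResult(doc_id=1, scores=[{"exact_match": 1.0}, {"exact_match": 1.0}, {"exact_match": 0.0}], filter_key="strict-match"),
    ]
}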
......@@ -637,11 +637,13 @@ def evaluate(
requests = instances_by_doc_id[doc_id]
if requests: # Make sure there are requests for this doc_id
# Get the metrics for this document
doc_metrics = [
task.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
# doc_metrics = [
# task.process_results(doc, response)
# for req in requests
# for response in req.filtered_resps[filter_key]
# ]
# TODO: doc_metrics only takes the first MetricResult for this filter_key, yielding (metric, values) pairs;
# not clear how multiple metrics or doc_ids should be handled here
doc_metrics = list(metrics[filter_key][0])
target = task.doc_to_target(doc)
example = {
......@@ -672,18 +674,16 @@ def evaluate(
# Process all metrics returned from calculate_metrics
for filter_key in metrics:
for sample_metric in metrics[filter_key]:
for metric_key, value in sample_metric:
task_output.sample_metrics[(metric_key, filter_key)].append(
# we get a list of metric results
# [MetricResult(doc_id=0, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None),
# MetricResult(doc_id=1, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None)]
for m_samples in metrics[filter_key]:
# m_samples is a MetricResult object
# m_samples.scores is a list of dicts
for metric, value in m_samples:
task_output.sample_metrics[(metric, filter_key)].append(
value
)
# metrics is a list of dictionaries, each containing metric names and their values
# e.g., [{"accuracy": 0.9}, {"f1": 0.8}]
# We need to iterate through each dictionary and extract the metric names and values
# for x in metrics:
# for metric, value in x.items():
# task_output.sample_metrics[(metric, filter_key)].append(value)
else:
# Fall back to the original approach for non-ConfigurableTask instances
indices = (
......
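Continuing the sketch above (illustrative values, not from the diff): after this loop, each sample_metrics entry holds one per-repeat value list per document, which is the list[list[float]] shape the aggregation change below consumes:

task_output.sample_metrics[("exact_match", "strict-match")]
# -> [[0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]   one list per doc, one value per repeat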
......@@ -111,7 +111,15 @@ class TaskOutput:
# TODO: Handle this better and allow aggregate functions other than mean.
agg_fn = mean
metric_key = f"{metric},{filter_key}"
self.agg_metrics[metric_key] = agg_fn(items)
# Handle multiple repeats: items is now list[list[float]]
if items and isinstance(items[0], list):
# Apply aggregation function to each repeat
self.agg_metrics[metric_key] = [
agg_fn(repeat) for repeat in zip(*items)
]
else:
# Backward compatibility: items is list[float]
self.agg_metrics[metric_key] = agg_fn(items)
self.sample_len = len(items) # TODO: same sample size for each metric?
if isinstance(bootstrap_iters, int):
stderr_fn = stderr_for_metric(
......
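A worked sketch of the new aggregation path (statistics.mean stands in for the harness's mean here; values are illustrative): with two documents and three repeats, zip(*items) transposes the per-document lists so the aggregation function runs once per repeat:

from statistics import mean

items = [
    [0.0, 1.0, 1.0],  # doc 0: one value per repeat
    [1.0, 1.0, 0.0],  # doc 1
]
[mean(repeat) for repeat in zip(*items)]  # -> [0.5, 1.0, 0.5], one aggregate per repeat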
......@@ -27,19 +27,17 @@ generation_kwargs:
- "<|im_end|>"
do_sample: false
temperature: 0.0
repeats: 1
repeats: 3
num_fewshot: 5
filter_list:
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
- function: "take_first"
- name: "flexible-extract"
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first"
metadata:
version: 3.0