Commit ba1d4483 authored by Baber

refactor: streamline metric calculations and enhance logging

parent 57b91fdb
@@ -94,3 +94,9 @@ class MetricResult:
for score_dict in self.scores
if metric_key in score_dict
]
@property
def metric_keys(self) -> list[str]:
if self.scores is None:
return []
return list(self.scores[0].keys()) if self.scores else []
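For quick reference while reading the rest of the diff, here is a minimal, self-contained sketch of how the new `metric_keys` property behaves. `MetricResultSketch` and its defaults are stand-ins; only the `scores`, `doc_id`, and `filter_key` fields and the property logic come from this commit.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MetricResultSketch:
    # Stand-in for MetricResult; only these three fields appear in the diff.
    scores: Optional[list[dict]] = None
    doc_id: Optional[int] = None
    filter_key: Optional[str] = None

    @property
    def metric_keys(self) -> list[str]:
        # No scores recorded -> no metric names; otherwise the keys of the first score dict.
        if not self.scores:
            return []
        return list(self.scores[0].keys())


result = MetricResultSketch(
    scores=[{"exact_match": 0.0}, {"exact_match": 1.0}],
    doc_id=0,
    filter_key="strict-match",
)
print(result.metric_keys)  # ['exact_match']
```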
@@ -887,7 +887,7 @@ class ConfigurableTask(Task):
eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
# self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}")
@@ -1771,13 +1771,13 @@ class ConfigurableTask(Task):
def calculate_metrics(
self,
instances_by_doc_id,
instances_by_doc_id=None,
filter_keys=None,
samples=None,
rank=1,
limit=None,
world_size=1,
) -> dict[str, list[dict]]:
) -> Optional[dict[str, list[dict]]]:
"""Calculate metrics for all datapoints in the task.
Args:
@@ -1791,12 +1791,23 @@
Returns:
dict: A mapping from filter key to the per-document metric results, or None if the task has no instances.
"""
if not self._instances:
return
from collections import defaultdict
if filter_keys is None:
filter_keys = [x.name for x in self._filters]
filter_keys = (
[x.name for x in self._filters]
if hasattr(self, "_filters")
else ["none"]
)
if isinstance(filter_keys, str):
filter_keys = [filter_keys]
if not instances_by_doc_id:
instances_by_doc_id = defaultdict(list)
for instance in self.instances:
instances_by_doc_id[instance.doc_id].append(instance)
all_metrics = collections.defaultdict(list)
# indices = samples.get(self.config.task, None) if samples is not None else None
for filter_key in filter_keys:
doc_iterator = self.doc_iterator(
rank=rank,
@@ -1808,7 +1819,6 @@
for doc_id, doc in doc_iterator:
# doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
metrics = [
self.process_results(doc, response)
for req in requests
@@ -1818,7 +1828,6 @@
else [req.filtered_resps[filter_key]]
)
]
all_metrics[filter_key].append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
)
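Taken together, the `calculate_metrics` hunks above now: return `None` when the task has no instances, derive `filter_keys` from `self._filters` (falling back to `"none"`), build `instances_by_doc_id` on the fly when the caller omits it, and collect one `MetricResult` per document and filter key. A condensed, runnable sketch of that flow; the stub objects, the plain-dict results, and the argument-free `doc_iterator` are simplifications, and the real method also handles non-list `filtered_resps` entries and the rank/limit/world_size arguments.

```python
from collections import defaultdict


def calculate_metrics_sketch(task, instances_by_doc_id=None, filter_keys=None):
    if not task.instances:
        return None  # mirrors the new Optional[...] return annotation

    # Fall back to the task's filter names, or "none" when no filters are defined.
    if filter_keys is None:
        filter_keys = [f.name for f in getattr(task, "_filters", [])] or ["none"]
    if isinstance(filter_keys, str):
        filter_keys = [filter_keys]

    # Group instances by document when the caller did not pass a mapping.
    if not instances_by_doc_id:
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)

    all_metrics = defaultdict(list)
    for filter_key in filter_keys:
        for doc_id, doc in task.doc_iterator():
            requests = instances_by_doc_id[doc_id]
            scores = [
                task.process_results(doc, response)
                for req in requests
                for response in req.filtered_resps[filter_key]
            ]
            # The real code wraps these in MetricResult(scores=..., doc_id=..., filter_key=...).
            all_metrics[filter_key].append(
                {"doc_id": doc_id, "filter_key": filter_key, "scores": scores}
            )
    return all_metrics


# Minimal stand-in objects so the sketch runs end to end (all hypothetical).
class _Instance:
    def __init__(self, doc_id, response):
        self.doc_id = doc_id
        self.filtered_resps = {"none": [response]}


class _Task:
    instances = [_Instance(0, "4"), _Instance(1, "7")]
    _filters = []
    docs = {0: {"answer": "4"}, 1: {"answer": "8"}}

    def doc_iterator(self):
        return self.docs.items()

    def process_results(self, doc, response):
        return {"exact_match": float(response == doc["answer"])}


print(dict(calculate_metrics_sketch(_Task())))
```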
@@ -609,16 +609,6 @@ def evaluate(
)
for filter_key in task.instances[0].filtered_resps.keys():
if hasattr(task, "calculate_metrics"):
# Use the new method if it exists (ConfigurableTask)
# metrics = task.calculate_metrics(
# instances_by_doc_id=instances_by_doc_id,
# filter_keys=filter_key,
# samples=samples,
# rank=RANK,
# limit=limit,
# world_size=WORLD_SIZE,
# )
# Add sample logging here too - similar to what's done in the else branch
if log_samples:
indices = (
@@ -636,15 +626,7 @@ def evaluate(
doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
if requests: # Make sure there are requests for this doc_id
# Get the metrics for this document
# doc_metrics = [
# task.process_results(doc, response)
# for req in requests
# for response in req.filtered_resps[filter_key]
# ]
# TODO: doc_metrics is a flat list of floats and it's not clear if we have multiple metrics
doc_metrics = [y for y in metrics[filter_key][0]]
doc_metrics = metrics[filter_key][doc_id_true].metric_keys
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
@@ -670,16 +652,18 @@
),
"target_hash": hash_string(str(target)),
}
example.update(
{metrics[filter_key][doc_id_true].metric_keys[0]: metrics[filter_key][doc_id_true]}
)
task_output.logged_samples.append(example)
# Process all metrics returned from calculate_metrics
for filter_key in metrics:
# we get a list of metric results
# [MetricResult(doc_id=0, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None),
# MetricResult(doc_id=1, scores=[{'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}, {'exact_match': np.float64(0.0)}], filter_key='strict-match', metric_name=None, metadata=None)]
for m_samples in metrics[filter_key]:
# m_samples is a MetricResult object
# m_samples.scores is a list of dicts
for metric, value in m_samples:
task_output.sample_metrics[(metric, filter_key)].append(
value
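The (truncated) loop above walks the dict returned by `calculate_metrics` (filter key -> list of `MetricResult`s) and fans each metric's per-repeat scores out into `task_output.sample_metrics`, keyed by `(metric, filter_key)`. A hedged sketch of that fan-out using plain dicts shaped like the example results quoted in the comments; exactly what iterating a real `MetricResult` yields is an assumption drawn from the `for metric, value in m_samples` unpacking.

```python
from collections import defaultdict

# Two documents, two repeats each, one metric (hypothetical values).
metrics = {
    "strict-match": [
        {"doc_id": 0, "scores": [{"exact_match": 0.0}, {"exact_match": 0.0}]},
        {"doc_id": 1, "scores": [{"exact_match": 1.0}, {"exact_match": 0.0}]},
    ]
}

sample_metrics = defaultdict(list)
for filter_key, results in metrics.items():
    for result in results:
        # One list of per-repeat values per metric name for this document.
        for metric in result["scores"][0]:
            values = [score[metric] for score in result["scores"]]
            sample_metrics[(metric, filter_key)].append(values)

print(dict(sample_metrics))
# {('exact_match', 'strict-match'): [[0.0, 0.0], [1.0, 0.0]]}
```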
@@ -128,9 +128,12 @@ class TaskOutput:
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
)
# TODO: what's the best way to calculate stderr across repeats?
# maybe take the mean per sample, then bootstrap?
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = [
(stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
for item in zip(*items)
][0]
else:
raise ValueError(
f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
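The new stderr expression appears to treat each entry of `items` as one document's list of per-repeat values: `zip(*items)` transposes so each column is a single repeat across documents, a stderr is computed per column, and `[0]` keeps only the first repeat's value (the TODO above flags this as provisional). A small illustration with an ordinary standard error of the mean standing in for the harness's bootstrap `stderr_fn`:

```python
import math


def sem(values):
    # Plain standard error of the mean; stands in for the bootstrap-based stderr_fn.
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return math.sqrt(var / n)


# One entry per document, one value per repeat (3 documents x 2 repeats, hypothetical).
items = [[1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]

# zip(*items) yields per-repeat columns across documents; [0] keeps only the first repeat,
# mirroring the expression in the hunk above.
stderr = [sem(col) if len(col) > 1 else "N/A" for col in zip(*items)][0]
print(stderr)  # stderr of (1.0, 0.0, 1.0) across the three documents
```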
@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
return map(lambda r: r, resps)
@register_filter("take_first_k")
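Note that after the one-character `TakeFirstFilter` change above, the filter no longer collapses repeats: the docstring still describes the old behavior, while the new lambda passes every response list through unchanged (presumably so the repeat-aware metric handling earlier in this commit sees all samples). A minimal before/after illustration:

```python
# `resps` layout (one list of repeated responses per request) follows the docstring above.
resps = [["first A", "second A"], ["first B", "second B"]]

old_behavior = list(map(lambda r: r[0], resps))  # keep only the first repeat
new_behavior = list(map(lambda r: r, resps))     # pass all repeats through unchanged

print(old_behavior)  # ['first A', 'first B']
print(new_behavior)  # [['first A', 'second A'], ['first B', 'second B']]
```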