Commit e30978c7 authored by Baber

add metric calculation method to configurable task

parent e9eb451e
...
@@ -1757,6 +1757,46 @@ class ConfigurableTask(Task):
             f"num_samples={len(self.eval_docs)})"
         )
+
+    def calculate_metrics(
+        self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
+    ):
+        """Calculate metrics for all datapoints in the task.
+
+        Args:
+            instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
+            filter_key (str): The filter key to use for filtered responses.
+            samples (dict, optional): Dictionary of sample indices to evaluate.
+            rank (int): The process rank.
+            limit (int, optional): Limit on number of examples to evaluate.
+            world_size (int): Total number of processes.
+
+        Returns:
+            list: A list of metrics calculated for each document.
+        """
+        all_metrics = []
+        # indices = samples.get(self.config.task, None) if samples is not None else None
+        doc_iterator = self.doc_iterator(
+            rank=rank,
+            limit=limit,
+            world_size=world_size,
+            # samples=indices,
+        )
+        for doc_id, doc in doc_iterator:
+            # doc_id_true = indices[doc_id] if indices else doc_id
+            requests = instances_by_doc_id[doc_id]
+            metrics = [
+                self.process_results(doc, response)
+                for req in requests
+                for response in req.filtered_resps[filter_key]
+            ]
+            all_metrics.extend(metrics)
+        return all_metrics
 class MultipleChoiceTask(Task):
     OUTPUT_TYPE = "loglikelihood"
...
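For orientation, the method added above just maps process_results over every filtered response of every document and flattens the results into one list of per-response metric dicts. A minimal, self-contained sketch of that shape transformation, ignoring doc_iterator, ranks, and limits; MockInstance and the exact_match scorer are illustrative stand-ins, not names from this repository:

    from dataclasses import dataclass, field

    @dataclass
    class MockInstance:
        # stand-in for the harness Instance: only the field the new method reads
        filtered_resps: dict = field(default_factory=dict)

    def process_results(doc, response):
        # stand-in scorer: one metric dict per individual response
        return {"exact_match": float(response == doc["answer"])}

    docs = {0: {"answer": "Paris"}, 1: {"answer": "red"}}
    instances_by_doc_id = {
        0: [MockInstance({"none": ["Paris"]})],
        1: [MockInstance({"none": ["blue", "red"]})],  # e.g. repeats > 1
    }

    # Same flattening as calculate_metrics: one dict per response, all docs concatenated.
    all_metrics = [
        process_results(docs[doc_id], response)
        for doc_id, requests in instances_by_doc_id.items()
        for req in requests
        for response in req.filtered_resps["none"]
    ]
    print(all_metrics)
    # [{'exact_match': 1.0}, {'exact_match': 0.0}, {'exact_match': 1.0}]

Because repeats or multiple responses per document are possible, the returned list can be longer than the number of documents, so callers aggregate per metric name rather than per document.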
...
@@ -596,6 +596,72 @@ def evaluate(
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
+            if hasattr(task, "calculate_metrics"):
+                # Use the new method if it exists (ConfigurableTask)
+                metrics = task.calculate_metrics(
+                    instances_by_doc_id=instances_by_doc_id,
+                    filter_key=filter_key,
+                    samples=samples,
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                )
+                # Add sample logging here too - similar to what's done in the else branch
+                if log_samples:
+                    indices = (
+                        samples.get(task_output.task_name, None)
+                        if samples is not None
+                        else None
+                    )
+                    doc_iterator = task.doc_iterator(
+                        rank=RANK,
+                        limit=limit,
+                        world_size=WORLD_SIZE,
+                        samples=indices,
+                    )
+                    for doc_id, doc in doc_iterator:
+                        doc_id_true = indices[doc_id] if indices else doc_id
+                        requests = instances_by_doc_id[doc_id]
+                        if requests:  # Make sure there are requests for this doc_id
+                            # Get the metrics for this document
+                            doc_metrics = [
+                                task.process_results(doc, response)
+                                for req in requests
+                                for response in req.filtered_resps[filter_key]
+                            ]
+                            target = task.doc_to_target(doc)
+                            example = {
+                                "doc_id": doc_id_true,
+                                "doc": doc,
+                                "target": target,
+                                "arguments": [req.args for req in requests],
+                                "resps": [req.resps for req in requests],
+                                "filtered_resps": [
+                                    req.filtered_resps[filter_key] for req in requests
+                                ],
+                                "filter": filter_key,
+                                "metrics": doc_metrics,
+                                "doc_hash": hash_string(
+                                    json.dumps(
+                                        requests[0].doc,
+                                        indent=2,
+                                        default=handle_non_serializable,
+                                        ensure_ascii=False,
+                                    )
+                                ),
+                                "prompt_hash": hash_string(requests[0].arguments[0]),
+                                "target_hash": hash_string(str(target)),
+                            }
+                            task_output.logged_samples.append(example)
+                # Process all metrics returned from calculate_metrics
+                for x in metrics:
+                    for metric, value in x.items():
+                        task_output.sample_metrics[(metric, filter_key)].append(value)
+            else:
+                # Fall back to the original approach for non-ConfigurableTask instances
                 indices = (
                     samples.get(task_output.task_name, None)
                     if samples is not None
...
@@ -630,7 +696,7 @@ def evaluate(
                                     req.filtered_resps[filter_key] for req in requests
                                 ],
                                 "filter": filter_key,
-                                "metrics": list({k for m in metrics for k in m.keys()}),
+                                "metrics": metrics,
                                 "doc_hash": hash_string(
                                     json.dumps(
                                         requests[0].doc,
...
@@ -642,10 +708,13 @@ def evaluate(
                                 "prompt_hash": hash_string(requests[0].arguments[0]),
                                 "target_hash": hash_string(str(target)),
                             }
-                            example.update(metrics)
+                            example.update({"metrics": metrics})
                             task_output.logged_samples.append(example)
-                    for metric, value in metrics.items():
-                        task_output.sample_metrics[(metric, filter_key)].append(value)
+                    for x in metrics:
+                        for metric, value in x.items():
+                            task_output.sample_metrics[(metric, filter_key)].append(
+                                value
+                            )
     if WORLD_SIZE > 1:
         # if multigpu, then gather data across all ranks to rank 0
...
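The evaluator change above is essentially a capability check: when the task object exposes calculate_metrics, all per-response metric dicts come back from one call; otherwise the loop falls back to calling process_results per document. A self-contained sketch of that dispatch under simplified signatures; LegacyTask, NewTask, and collect are hypothetical stand-ins, and the real methods take more arguments than shown here:

    from collections import defaultdict

    class LegacyTask:
        def process_results(self, doc, responses):
            # stand-in scorer: one metric dict per document
            return {"acc": float(responses[0] == doc["answer"])}

    class NewTask(LegacyTask):
        def calculate_metrics(self, scored):
            # new path: return the already-flattened list of per-response metric dicts
            return [self.process_results(doc, resps) for doc, resps in scored]

    def collect(task, scored, filter_key="none"):
        sample_metrics = defaultdict(list)
        if hasattr(task, "calculate_metrics"):
            metrics = task.calculate_metrics(scored)                   # new path
        else:
            metrics = [task.process_results(d, r) for d, r in scored]  # fallback
        for per_response in metrics:  # both branches yield a list of dicts
            for name, value in per_response.items():
                sample_metrics[(name, filter_key)].append(value)
        return dict(sample_metrics)

    scored = [({"answer": "yes"}, ["yes"]), ({"answer": "no"}, ["yes"])]
    print(collect(NewTask(), scored))     # {('acc', 'none'): [1.0, 0.0]}
    print(collect(LegacyTask(), scored))  # {('acc', 'none'): [1.0, 0.0]}

Both branches are then aggregated the same way, which is why the diff rewrites the final loop to iterate over a list of metric dicts rather than a single dict.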