Commit e30978c7 authored by Baber

add metric calculation method to configurable task

parent e9eb451e
@@ -1757,6 +1757,46 @@ class ConfigurableTask(Task):
             f"num_samples={len(self.eval_docs)})"
         )
 
+    def calculate_metrics(
+        self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
+    ):
+        """Calculate metrics for all datapoints in the task.
+
+        Args:
+            instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
+            filter_key (str): The filter key to use for filtered responses.
+            samples (dict, optional): Dictionary of sample indices to evaluate.
+            rank (int): The process rank.
+            limit (int, optional): Limit on number of examples to evaluate.
+            world_size (int): Total number of processes.
+
+        Returns:
+            list: A list of metrics calculated for each document.
+        """
+        all_metrics = []
+        # indices = samples.get(self.config.task, None) if samples is not None else None
+        doc_iterator = self.doc_iterator(
+            rank=rank,
+            limit=limit,
+            world_size=world_size,
+            # samples=indices,
+        )
+        for doc_id, doc in doc_iterator:
+            # doc_id_true = indices[doc_id] if indices else doc_id
+            requests = instances_by_doc_id[doc_id]
+            metrics = [
+                self.process_results(doc, response)
+                for req in requests
+                for response in req.filtered_resps[filter_key]
+            ]
+            all_metrics.extend(metrics)
+        return all_metrics
+
 
 class MultipleChoiceTask(Task):
     OUTPUT_TYPE = "loglikelihood"
...
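For orientation, the method added above flattens the per-response metric dicts across all documents into a single list: one dict per filtered response, not one per document. The sketch below imitates that double loop with stand-in objects (FakeInstance and a toy process_results) that are NOT part of the harness API; it only illustrates the shapes involved, under the assumption that filtered_resps maps a filter key to a list of response strings.

from collections import defaultdict


class FakeInstance:
    """Hypothetical stand-in for lm_eval's Instance: one request plus its filtered responses."""

    def __init__(self, doc_id, filtered_resps):
        self.doc_id = doc_id
        self.filtered_resps = filtered_resps  # e.g. {"none": ["model output"]}


def process_results(doc, response):
    # Toy stand-in for Task.process_results: maps one response to a metrics dict.
    return {"exact_match": float(response == doc["answer"])}


docs = {0: {"answer": "4"}, 1: {"answer": "9"}}
instances_by_doc_id = defaultdict(list)
instances_by_doc_id[0].append(FakeInstance(0, {"none": ["4"]}))
instances_by_doc_id[1].append(FakeInstance(1, {"none": ["7"]}))

# Same double loop as calculate_metrics: one metrics dict per filtered response.
all_metrics = [
    process_results(docs[doc_id], response)
    for doc_id, requests in instances_by_doc_id.items()
    for req in requests
    for response in req.filtered_resps["none"]
]
print(all_metrics)  # [{'exact_match': 1.0}, {'exact_match': 0.0}]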
@@ -596,56 +596,125 @@ def evaluate(
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
-            indices = (
-                samples.get(task_output.task_name, None)
-                if samples is not None
-                else None
-            )
-            doc_iterator = task.doc_iterator(
-                rank=RANK,
-                limit=limit,
-                world_size=WORLD_SIZE,
-                samples=indices,
-            )
-            for doc_id, doc in doc_iterator:
-                if indices:
-                    doc_id_true = indices[doc_id]
-                else:
-                    doc_id_true = doc_id
-                requests = instances_by_doc_id[doc_id]
-                metrics: list[dict] = [
-                    task.process_results(doc, response)
-                    for req in requests
-                    for response in req.filtered_resps[filter_key]
-                ]
-                if log_samples:
-                    target = task.doc_to_target(doc)
-                    example = {
-                        "doc_id": doc_id_true,
-                        "doc": doc,
-                        "target": target,
-                        "arguments": [req.args for req in requests],
-                        "resps": [req.resps for req in requests],
-                        "filtered_resps": [
-                            req.filtered_resps[filter_key] for req in requests
-                        ],
-                        "filter": filter_key,
-                        "metrics": list({k for m in metrics for k in m.keys()}),
-                        "doc_hash": hash_string(
-                            json.dumps(
-                                requests[0].doc,
-                                indent=2,
-                                default=handle_non_serializable,
-                                ensure_ascii=False,
-                            )
-                        ),
-                        "prompt_hash": hash_string(requests[0].arguments[0]),
-                        "target_hash": hash_string(str(target)),
-                    }
-                    example.update(metrics)
-                    task_output.logged_samples.append(example)
-                for metric, value in metrics.items():
-                    task_output.sample_metrics[(metric, filter_key)].append(value)
+            if hasattr(task, "calculate_metrics"):
+                # Use the new method if it exists (ConfigurableTask)
+                metrics = task.calculate_metrics(
+                    instances_by_doc_id=instances_by_doc_id,
+                    filter_key=filter_key,
+                    samples=samples,
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                )
+
+                # Add sample logging here too - similar to what's done in the else branch
+                if log_samples:
+                    indices = (
+                        samples.get(task_output.task_name, None)
+                        if samples is not None
+                        else None
+                    )
+                    doc_iterator = task.doc_iterator(
+                        rank=RANK,
+                        limit=limit,
+                        world_size=WORLD_SIZE,
+                        samples=indices,
+                    )
+                    for doc_id, doc in doc_iterator:
+                        doc_id_true = indices[doc_id] if indices else doc_id
+                        requests = instances_by_doc_id[doc_id]
+                        if requests:  # Make sure there are requests for this doc_id
+                            # Get the metrics for this document
+                            doc_metrics = [
+                                task.process_results(doc, response)
+                                for req in requests
+                                for response in req.filtered_resps[filter_key]
+                            ]
+                            target = task.doc_to_target(doc)
+                            example = {
+                                "doc_id": doc_id_true,
+                                "doc": doc,
+                                "target": target,
+                                "arguments": [req.args for req in requests],
+                                "resps": [req.resps for req in requests],
+                                "filtered_resps": [
+                                    req.filtered_resps[filter_key] for req in requests
+                                ],
+                                "filter": filter_key,
+                                "metrics": doc_metrics,
+                                "doc_hash": hash_string(
+                                    json.dumps(
+                                        requests[0].doc,
+                                        indent=2,
+                                        default=handle_non_serializable,
+                                        ensure_ascii=False,
+                                    )
+                                ),
+                                "prompt_hash": hash_string(requests[0].arguments[0]),
+                                "target_hash": hash_string(str(target)),
+                            }
+                            task_output.logged_samples.append(example)
+                # Process all metrics returned from calculate_metrics
+                for x in metrics:
+                    for metric, value in x.items():
+                        task_output.sample_metrics[(metric, filter_key)].append(value)
+            else:
+                # Fall back to the original approach for non-ConfigurableTask instances
+                indices = (
+                    samples.get(task_output.task_name, None)
+                    if samples is not None
+                    else None
+                )
+                doc_iterator = task.doc_iterator(
+                    rank=RANK,
+                    limit=limit,
+                    world_size=WORLD_SIZE,
+                    samples=indices,
+                )
+                for doc_id, doc in doc_iterator:
+                    if indices:
+                        doc_id_true = indices[doc_id]
+                    else:
+                        doc_id_true = doc_id
+                    requests = instances_by_doc_id[doc_id]
+                    metrics: list[dict] = [
+                        task.process_results(doc, response)
+                        for req in requests
+                        for response in req.filtered_resps[filter_key]
+                    ]
+                    if log_samples:
+                        target = task.doc_to_target(doc)
+                        example = {
+                            "doc_id": doc_id_true,
+                            "doc": doc,
+                            "target": target,
+                            "arguments": [req.args for req in requests],
+                            "resps": [req.resps for req in requests],
+                            "filtered_resps": [
+                                req.filtered_resps[filter_key] for req in requests
+                            ],
+                            "filter": filter_key,
+                            "metrics": metrics,
+                            "doc_hash": hash_string(
+                                json.dumps(
+                                    requests[0].doc,
+                                    indent=2,
+                                    default=handle_non_serializable,
+                                    ensure_ascii=False,
+                                )
+                            ),
+                            "prompt_hash": hash_string(requests[0].arguments[0]),
+                            "target_hash": hash_string(str(target)),
+                        }
+                        example.update({"metrics": metrics})
+                        task_output.logged_samples.append(example)
+                    for x in metrics:
+                        for metric, value in x.items():
+                            task_output.sample_metrics[(metric, filter_key)].append(
+                                value
+                            )
 
     if WORLD_SIZE > 1:
         # if multigpu, then gather data across all ranks to rank 0
...
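The evaluator change above is a duck-typing dispatch: a task that exposes calculate_metrics now owns the metric loop, and anything else keeps the original inline path. A minimal sketch of that pattern, using hypothetical NewStyleTask and collect_metrics names that are not part of the harness API, with the legacy branch elided:

from collections import defaultdict


class NewStyleTask:
    """Hypothetical stand-in for ConfigurableTask, which now owns metric calculation."""

    def calculate_metrics(self, **kwargs):
        return [{"acc": 1.0}, {"acc": 0.0}]  # one dict per filtered response


def collect_metrics(task, **kwargs):
    # Same hasattr check as the patched evaluate(): prefer the task's own
    # metric loop, otherwise fall back to the original inline approach.
    if hasattr(task, "calculate_metrics"):
        return task.calculate_metrics(**kwargs)
    raise NotImplementedError("legacy per-document loop elided in this sketch")


# Aggregation mirrors the patched loop: metrics is a list of dicts, so it is
# iterated dict by dict rather than treated as a single mapping.
sample_metrics = defaultdict(list)
for metrics_dict in collect_metrics(NewStyleTask()):
    for metric, value in metrics_dict.items():
        sample_metrics[(metric, "none")].append(value)
print(dict(sample_metrics))  # {('acc', 'none'): [1.0, 0.0]}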